diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 61c8d14039db..03849995735d 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -31,7 +31,7 @@ RUN apt-get install -y curl wget gnupg python3 python-is-python3 python3-pip git RUN python -m pip install \ pip==23.3.1 \ setuptools==68.2.2 \ - poetry==1.5.1 + poetry==1.7.1 USER $USERNAME ENV PATH="/home/$USERNAME/.local/bin:${PATH}" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 71a8aea59859..8dac63a20598 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,6 +3,9 @@ # Default code owners * @danieljanes @tanertopal +# README.md +README.md @jafermarq @tanertopal @danieljanes + # Flower Baselines /baselines @jafermarq @tanertopal @danieljanes diff --git a/.github/ISSUE_TEMPLATE/baseline_request.yml b/.github/ISSUE_TEMPLATE/baseline_request.yml index 49ae922a94ad..8df8eec22a21 100644 --- a/.github/ISSUE_TEMPLATE/baseline_request.yml +++ b/.github/ISSUE_TEMPLATE/baseline_request.yml @@ -38,18 +38,18 @@ body: attributes: label: For first time contributors value: | - - [ ] Read the [`first contribution` doc](https://flower.dev/docs/first-time-contributors.html) + - [ ] Read the [`first contribution` doc](https://flower.ai/docs/first-time-contributors.html) - [ ] Complete the Flower tutorial - [ ] Read the Flower Baselines docs to get an overview: - - [ ] [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) - - [ ] [How to contribute a Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) + - [ ] [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) + - [ ] [How to contribute a Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) - type: checkboxes attributes: label: Prepare - understand the scope options: - label: Read the paper linked above - label: Decide which experiments you'd like to reproduce. The more the better! - - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). + - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). - label: You can use as reference [other baselines](https://github.com/adap/flower/tree/main/baselines) that the community merged following those steps. - type: checkboxes attributes: diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index d0ab46d60b57..9e0ed23d4dbb 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,14 +1,14 @@ blank_issues_enabled: false contact_links: - name: Slack Channel - url: https://flower.dev/join-slack + url: https://flower.ai/join-slack about: Connect with other Flower users and contributors and discuss with them or ask them questions. - name: Discussion url: https://github.com/adap/flower/discussions - about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. + about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. - name: Flower Issues url: https://github.com/adap/flower/issues about: Contribute new features/enhancements, report bugs, or improve the documentation. 
- name: Flower Mail - url: https://flower.dev/ - about: If your project needs professional support please contact the Flower team (hello@flower.dev). + url: https://flower.ai/ + about: If your project needs professional support please contact the Flower team (hello@flower.ai). diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 479f88c1bbd5..0077bbab0909 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -36,15 +36,15 @@ Example: The variable `rnd` was renamed to `server_round` to improve readability - [ ] Implement proposed change - [ ] Write tests -- [ ] Update [documentation](https://flower.dev/docs/writing-documentation.html) +- [ ] Update [documentation](https://flower.ai/docs/writing-documentation.html) - [ ] Update the changelog entry below - [ ] Make CI checks pass -- [ ] Ping maintainers on [Slack](https://flower.dev/join-slack/) (channel `#contributions`) +- [ ] Ping maintainers on [Slack](https://flower.ai/join-slack/) (channel `#contributions`) diff --git a/.github/actions/bootstrap/action.yml b/.github/actions/bootstrap/action.yml index bab5113bc567..7b1716bbc954 100644 --- a/.github/actions/bootstrap/action.yml +++ b/.github/actions/bootstrap/action.yml @@ -12,7 +12,7 @@ inputs: default: 68.2.2 poetry-version: description: "Version of poetry to be installed using pip" - default: 1.5.1 + default: 1.7.1 outputs: python-version: description: "Version range or exact version of Python or PyPy" @@ -30,7 +30,7 @@ runs: using: "composite" steps: - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ inputs.python-version }} - name: Install build tools diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 4a1289d9175a..99ec0671db66 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -114,7 +114,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1 + uses: actions/download-artifact@eaceaf801fd36c7dee90939fad912460b18a1ffe # v4.1.2 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index bfb26053836d..c4485fe72d10 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -1,14 +1,5 @@ name: Baselines -# The aim of this workflow is to test only the changed (or added) baseline. -# Here is the rough idea of how it works (more details are presented later in the comments): -# 1. Checks for the changes between the current branch and the main - in case of PR - -# or between the HEAD and HEAD~1 (main last commit and the previous one) - in case of -# a push to main. -# 2. Fails the test if there are changes to more than one baseline. Passes the test -# (skips the rests) if there are no changes to any baselines. Follows the test if only -# one baseline is added or modified. -# 3. Sets up the env specified for the baseline. -# 4. Runs the tests. 
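> Editor's note: the rewrite that follows replaces this hand-rolled change detection with `dorny/paths-filter` plus a job matrix. As a minimal sketch of that pattern (filter names and paths here are illustrative placeholders, not the exact Flower configuration):

```yaml
# Sketch: one job computes which project directories changed; a second job
# fans out one matrix entry per changed directory. `outputs.changes` from
# dorny/paths-filter is a JSON array of the filter names that matched.
jobs:
  changes:
    runs-on: ubuntu-22.04
    outputs:
      projects: ${{ steps.filter.outputs.changes }}
    steps:
      - uses: actions/checkout@v4
      - uses: dorny/paths-filter@v3
        id: filter
        with:
          filters: |
            proj_a: 'baselines/proj_a/**'
            proj_b: 'baselines/proj_b/**'
  test:
    needs: changes
    if: ${{ needs.changes.outputs.projects != '[]' }}
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        project: ${{ fromJSON(needs.changes.outputs.projects) }}
    steps:
      - run: echo "Testing ${{ matrix.project }}"
```

Compared to diffing against `main` by hand, this moves the "which directories changed?" question into a maintained action and turns each changed directory into its own matrix job.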
+
 on:
   push:
     branches:
@@ -24,112 +15,75 @@ concurrency:
 env:
   FLWR_TELEMETRY_ENABLED: 0
 
-defaults:
-  run:
-    working-directory: baselines
-
 jobs:
-  test_baselines:
-    name: Test
+  changes:
     runs-on: ubuntu-22.04
+    permissions:
+      pull-requests: read
+    outputs:
+      baselines: ${{ steps.filter.outputs.changes }}
     steps:
       - uses: actions/checkout@v4
-        # The depth two of the checkout is needed in case of merging to the main
-        # because we compare the HEAD (current version) with HEAD~1 (version before
-        # the PR was merged)
-        with:
-          fetch-depth: 2
-      - name: Fetch main branch
-        run: |
-          # The main branch is needed in case of the PR to make a comparison (by
-          # default the workflow takes as little information as possible - it does not
-          # have the history
-          if [ ${{ github.event_name }} == "pull_request" ]
-          then
-            git fetch origin main:main
-          fi
-      - name: Find changed/new baselines
-        id: find_changed_baselines_dirs
+
+      - shell: bash
         run: |
-          if [ ${{ github.event_name }} == "push" ]
-          then
-            # Push event triggered when merging to main
-            change_references="HEAD..HEAD~1"
-          else
-            # Pull request event triggered for any commit to a pull request
-            change_references="main..HEAD"
-          fi
-          dirs=$(git diff --dirstat=files,0 ${change_references} . | awk '{print $2}' | grep -E '^baselines/[^/]*/$' | \
-            grep -v \
-            -e '^baselines/dev' \
-            -e '^baselines/baseline_template' \
-            -e '^baselines/flwr_baselines' \
-            -e '^baselines/doc' \
-            | sed 's/^baselines\///')
-          # git diff --dirstat=files,0 ${change_references} . - checks the differences
-          # and a file is counted as changed if more than 0 lines were changed
-          # it returns the results in the format x.y% path/to/dir/
-          # awk '{print $2}' - takes only the directories (skips the percentages)
-          # grep -E '^baselines/[^/]*/$' - takes only the paths that start with
-          # baseline (and have at least one subdirectory)
-          # grep -v -e ... - excludes the `baseline_template`, `dev`, `flwr_baselines`
-          # sed 's/^baselines\///' - narrows down the path to baseline/
-          echo "Detected changed directories: ${dirs}"
-          # Save changed dirs to output of this step
-          EOF=$(dd if=/dev/urandom bs=15 count=1 status=none | base64)
-          echo "dirs<<EOF" >> "$GITHUB_OUTPUT"
-          for dir in $dirs
-          do
-            echo "$dir" >> "$GITHUB_OUTPUT"
-          done
-          echo "EOF" >> "$GITHUB_OUTPUT"
-      - name: Validate changed/new baselines
-        id: validate_changed_baselines_dirs
+          # create a list of all directories in baselines
+          {
+            echo 'FILTER_PATHS<<EOF'
+            …
+            echo EOF
+          } >> "$GITHUB_ENV"
+
+      - uses: dorny/paths-filter@v3
+        id: filter
+        with:
+          filters: ${{ env.FILTER_PATHS }}
+
+      - if: ${{ github.event.pull_request.head.repo.fork }}
         run: |
-          dirs="${{ steps.find_changed_baselines_dirs.outputs.dirs }}"
-          dirs_array=()
-          if [[ -n $dirs ]]; then
-            while IFS= read -r line; do
-              dirs_array+=("$line")
-            done <<< "$dirs"
-          fi
-          length=${#dirs_array[@]}
-          echo "The number of changed baselines is $length"
-
-          if [ $length -gt 1 ]; then
-            echo "The changes should only apply to a single baseline"
-            exit 1
+          CHANGES=$(echo "${{ toJson(steps.filter.outputs.changes) }}" | jq '. | length')
+          if [ "$CHANGES" -gt 1 ]; then
+            echo "::error ::The changes should only apply to a single baseline."
+            exit 1
           fi
-
-          if [ $length -eq 0 ]; then
-            echo "The baselines were not changed - skipping the remaining steps."
- echo "baseline_changed=false" >> "$GITHUB_OUTPUT" - exit 0 - fi - - echo "changed_dir=${dirs[0]}" >> "$GITHUB_OUTPUT" - echo "baseline_changed=true" >> "$GITHUB_OUTPUT" + + test: + runs-on: ubuntu-22.04 + needs: changes + if: ${{ needs.changes.outputs.baselines != '' && toJson(fromJson(needs.changes.outputs.baselines)) != '[]' }} + strategy: + matrix: + baseline: ${{ fromJSON(needs.changes.outputs.baselines) }} + steps: + - uses: actions/checkout@v4 + - name: Bootstrap - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: python-version: '3.10' + - name: Install dependencies - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - changed_dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - cd "${changed_dir}" - python -m poetry install - - name: Test - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline.sh $dir - - name: Test Structure - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline-structure.sh $dir + working-directory: baselines/${{ matrix.baseline }} + run: python -m poetry install + + - name: Testing ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline.sh ${{ matrix.baseline }} + - name: Test Structure of ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline-structure.sh ${{ matrix.baseline }} diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index a1d918aa5b29..7efc879bcee2 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -19,13 +19,13 @@ jobs: uses: ./.github/actions/bootstrap - name: Cache restore SDK build - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache restore example build - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: examples/quickstart-cpp/build/ key: ${{ runner.os }}-example-build @@ -82,14 +82,14 @@ jobs: fi - name: Cache save SDK build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache save example build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: examples/quickstart-cpp/build/ diff --git a/.github/workflows/docker-client.yml b/.github/workflows/docker-client.yml index 47083b258982..3c2d83596733 100644 --- a/.github/workflows/docker-client.yml +++ b/.github/workflows/docker-client.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: flwr-version: - description: "Version of Flower e.g. (1.6.0)." + description: "Version of Flower e.g. (1.7.0)." required: true type: string diff --git a/.github/workflows/docker-server.yml b/.github/workflows/docker-server.yml index f580a8e9a280..1e43715207d4 100644 --- a/.github/workflows/docker-server.yml +++ b/.github/workflows/docker-server.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: flwr-version: - description: "Version of Flower e.g. (1.6.0)." + description: "Version of Flower e.g. (1.7.0)." 
required: true type: string base-image-tag: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 94f3495a20ef..a4c769fdf850 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -43,8 +43,9 @@ jobs: AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} + DOCS_BUCKET: flower.ai run: | - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://flower.dev/docs/framework - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://flower.dev/docs/baselines - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://flower.dev/docs/examples - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://flower.dev/docs/datasets + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/framework + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/baselines + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/examples + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/datasets diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index bf3843c25bd4..db9f65a4f4f3 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -14,6 +14,7 @@ concurrency: env: FLWR_TELEMETRY_ENABLED: 0 + ARTIFACT_BUCKET: artifact.flower.ai jobs: wheel: @@ -43,7 +44,7 @@ jobs: echo "SHORT_SHA=$sha_short" >> "$GITHUB_OUTPUT" [ -z "${{ github.head_ref }}" ] && dir="${{ github.ref_name }}" || dir="pr/${{ github.head_ref }}" echo "DIR=$dir" >> "$GITHUB_OUTPUT" - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://artifact.flower.dev/py/$dir/$sha_short --recursive + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://${{ env.ARTIFACT_BUCKET }}/py/$dir/$sha_short --recursive outputs: whl_path: ${{ steps.upload.outputs.WHL_PATH }} short_sha: ${{ steps.upload.outputs.SHORT_SHA }} @@ -89,11 +90,6 @@ jobs: from torchvision.datasets import MNIST MNIST('./data', download=True) - - directory: mxnet - dataset: | - import mxnet as mx - mx.test_utils.get_mnist() - - directory: scikit-learn dataset: | import openml @@ -128,7 +124,7 @@ jobs: - name: Install Flower wheel from artifact store if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Download dataset if: ${{ matrix.dataset }} run: python -c "${{ matrix.dataset }}" @@ -138,6 +134,12 @@ jobs: run: python simulation.py - name: Run driver test run: ./../test_driver.sh "${{ matrix.directory }}" + - name: Run driver test with REST + if: ${{ matrix.directory == 'bare' }} + run: ./../test_driver.sh bare rest + - 
name: Run driver test with SQLite database + if: ${{ matrix.directory == 'bare' }} + run: ./../test_driver.sh bare sqlite strategies: runs-on: ubuntu-22.04 @@ -163,9 +165,9 @@ jobs: - name: Install Flower wheel from artifact store if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Cache Datasets - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: "~/.keras" key: keras-datasets diff --git a/.github/workflows/framework-draft-release.yml b/.github/workflows/framework-draft-release.yml index 959d17249765..91a89953cf96 100644 --- a/.github/workflows/framework-draft-release.yml +++ b/.github/workflows/framework-draft-release.yml @@ -5,6 +5,9 @@ on: tags: - "v*.*.*" +env: + ARTIFACT_BUCKET: artifact.flower.ai + jobs: publish: if: ${{ github.repository == 'adap/flower' }} @@ -26,16 +29,16 @@ jobs: run: | tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" - + wheel_name="flwr-${tag_name}-py3-none-any.whl" echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" - + tar_name="flwr-${tag_name}.tar.gz" echo "TAR_NAME=$tar_name" >> "$GITHUB_ENV" - wheel_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${tar_name}" - + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${tar_name}" + curl $wheel_url --output $wheel_name curl $tar_url --output $tar_name - name: Upload wheel @@ -44,14 +47,14 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets. 
AWS_SECRET_ACCESS_KEY }} run: | - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} - + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} + - name: Generate body run: | ./dev/get-latest-changelog.sh > body.md cat body.md - + - name: Release uses: softprops/action-gh-release@v1 with: diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index 0f3cda8abae3..04b68fd38af9 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -3,11 +3,14 @@ name: Publish `flwr` release on PyPI on: release: types: [released] - + concurrency: group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }} cancel-in-progress: true - + +env: + ARTIFACT_BUCKET: artifact.flower.ai + jobs: publish: if: ${{ github.repository == 'adap/flower' }} @@ -28,13 +31,15 @@ jobs: run: | TAG_NAME=$(echo "${GITHUB_REF_NAME}" | cut -c2-) - wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" + wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" tar_name="flwr-${TAG_NAME}.tar.gz" - wheel_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${tar_name}" + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${tar_name}" + + mkdir -p dist - curl $wheel_url --output $wheel_name - curl $tar_url --output $tar_name + curl $wheel_url --output dist/$wheel_name + curl $tar_url --output dist/$tar_name python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 264548c87669..1f01c0e9717d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Changelog -Flower changes are tracked as part of the documentation: [Flower Changelog](https://flower.dev/docs/changelog.html). +Flower changes are tracked as part of the documentation: [Flower Changelog](https://flower.ai/docs/changelog.html). The changelog source can be edited here: `doc/source/changelog.rst` diff --git a/README.md b/README.md index 750b5cdb4b93..90faa2358fa6 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ # Flower: A Friendly Federated Learning Framework
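> Editor's note before the README changes below: the two release workflows above share one idea — the wheel and sdist built once on `main` are copied forward (artifact bucket → release path → PyPI) instead of being rebuilt at each stage. A condensed, hypothetical sketch of the final publish job (URL, version, and file names are placeholders; the real values come from `ARTIFACT_BUCKET` and the git tag, and Poetry is assumed to be installed by the bootstrap action):

```yaml
# Hypothetical condensed publish steps: fetch the already-built artifacts
# into dist/ and hand exactly those bytes to Poetry for upload.
- name: Download release artifacts
  run: |
    mkdir -p dist
    curl "https://artifact.example.org/py/release/v1.8.0/flwr-1.8.0-py3-none-any.whl" \
      --output dist/flwr-1.8.0-py3-none-any.whl
    curl "https://artifact.example.org/py/release/v1.8.0/flwr-1.8.0.tar.gz" \
      --output dist/flwr-1.8.0.tar.gz

- name: Publish to PyPI
  run: python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }}
```

This also appears to be why the diff adds `mkdir -p dist`: `poetry publish` uploads whatever it finds in `dist/`, so downloading into the repository root would leave Poetry with nothing to publish.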

-  <a href="https://flower.dev/">
-    <img src="./doc/source/_static/flower_white_border.png" width="600px" alt="Flower Website" />
+  <a href="https://flower.ai/">
+    <img src="./doc/source/_static/flower_white_border.png" width="600px" alt="Flower Website" />

-    <a href="https://flower.dev/">Website</a> |
-    <a href="https://flower.dev/blog">Blog</a> |
-    <a href="https://flower.dev/docs/">Docs</a> |
-    <a href="https://flower.dev/conf/flower-summit-2022">Conference</a> |
-    <a href="https://flower.dev/join-slack">Slack</a>
+    <a href="https://flower.ai/">Website</a> |
+    <a href="https://flower.ai/blog">Blog</a> |
+    <a href="https://flower.ai/docs/">Docs</a> |
+    <a href="https://flower.ai/conf/flower-summit-2022">Conference</a> |
+    <a href="https://flower.ai/join-slack">Slack</a>

@@ -18,7 +18,7 @@ [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/adap/flower/blob/main/CONTRIBUTING.md) ![Build](https://github.com/adap/flower/actions/workflows/framework.yml/badge.svg) [![Downloads](https://static.pepy.tech/badge/flwr)](https://pepy.tech/project/flwr) -[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.dev/join-slack) +[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.ai/join-slack) Flower (`flwr`) is a framework for building federated learning systems. The design of Flower is based on a few guiding principles: @@ -33,14 +33,13 @@ design of Flower is based on a few guiding principles: - **Framework-agnostic**: Different machine learning frameworks have different strengths. Flower can be used with any machine learning framework, for - example, [PyTorch](https://pytorch.org), - [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) + example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) for users who enjoy computing gradients by hand. - **Understandable**: Flower is written with maintainability in mind. The community is encouraged to both read and contribute to the codebase. -Meet the Flower community on [flower.dev](https://flower.dev)! +Meet the Flower community on [flower.ai](https://flower.ai)! ## Federated Learning Tutorial @@ -56,11 +55,11 @@ Flower's goal is to make federated learning accessible to everyone. This series 2. **Using Strategies in Federated Learning** - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-use-a-federated-learning-strategy-pytorch.ipynb)) + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb)) 3. 
**Building Strategies for Federated Learning** - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb)) + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb)) 4. **Custom Clients for Federated Learning** @@ -74,19 +73,19 @@ Stay tuned, more tutorials are coming soon. Topics include **Privacy and Securit ## Documentation -[Flower Docs](https://flower.dev/docs): +[Flower Docs](https://flower.ai/docs): -- [Installation](https://flower.dev/docs/framework/how-to-install-flower.html) -- [Quickstart (TensorFlow)](https://flower.dev/docs/framework/tutorial-quickstart-tensorflow.html) -- [Quickstart (PyTorch)](https://flower.dev/docs/framework/tutorial-quickstart-pytorch.html) -- [Quickstart (Hugging Face)](https://flower.dev/docs/framework/tutorial-quickstart-huggingface.html) -- [Quickstart (PyTorch Lightning [code example])](https://flower.dev/docs/framework/tutorial-quickstart-pytorch-lightning.html) -- [Quickstart (Pandas)](https://flower.dev/docs/framework/tutorial-quickstart-pandas.html) -- [Quickstart (fastai)](https://flower.dev/docs/framework/tutorial-quickstart-fastai.html) -- [Quickstart (JAX)](https://flower.dev/docs/framework/tutorial-quickstart-jax.html) -- [Quickstart (scikit-learn)](https://flower.dev/docs/framework/tutorial-quickstart-scikitlearn.html) -- [Quickstart (Android [TFLite])](https://flower.dev/docs/framework/tutorial-quickstart-android.html) -- [Quickstart (iOS [CoreML])](https://flower.dev/docs/framework/tutorial-quickstart-ios.html) +- [Installation](https://flower.ai/docs/framework/how-to-install-flower.html) +- [Quickstart (TensorFlow)](https://flower.ai/docs/framework/tutorial-quickstart-tensorflow.html) +- [Quickstart (PyTorch)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch.html) +- [Quickstart (Hugging Face)](https://flower.ai/docs/framework/tutorial-quickstart-huggingface.html) +- [Quickstart (PyTorch Lightning)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch-lightning.html) +- [Quickstart (Pandas)](https://flower.ai/docs/framework/tutorial-quickstart-pandas.html) +- [Quickstart (fastai)](https://flower.ai/docs/framework/tutorial-quickstart-fastai.html) +- [Quickstart (JAX)](https://flower.ai/docs/framework/tutorial-quickstart-jax.html) +- [Quickstart (scikit-learn)](https://flower.ai/docs/framework/tutorial-quickstart-scikitlearn.html) +- [Quickstart (Android [TFLite])](https://flower.ai/docs/framework/tutorial-quickstart-android.html) +- [Quickstart (iOS [CoreML])](https://flower.ai/docs/framework/tutorial-quickstart-ios.html) ## Flower Baselines @@ -99,17 +98,24 @@ Flower Baselines is a collection of community-contributed projects that reproduc - [FedMLB](https://github.com/adap/flower/tree/main/baselines/fedmlb) - [FedPer](https://github.com/adap/flower/tree/main/baselines/fedper) - [FedProx](https://github.com/adap/flower/tree/main/baselines/fedprox) +- [FedNova](https://github.com/adap/flower/tree/main/baselines/fednova) 
+- [HeteroFL](https://github.com/adap/flower/tree/main/baselines/heterofl) +- [FedAvgM](https://github.com/adap/flower/tree/main/baselines/fedavgm) +- [FedStar](https://github.com/adap/flower/tree/main/baselines/fedstar) - [FedWav2vec2](https://github.com/adap/flower/tree/main/baselines/fedwav2vec2) - [FjORD](https://github.com/adap/flower/tree/main/baselines/fjord) - [MOON](https://github.com/adap/flower/tree/main/baselines/moon) - [niid-Bench](https://github.com/adap/flower/tree/main/baselines/niid_bench) - [TAMUNA](https://github.com/adap/flower/tree/main/baselines/tamuna) +- [FedVSSL](https://github.com/adap/flower/tree/main/baselines/fedvssl) +- [FedXGBoost](https://github.com/adap/flower/tree/main/baselines/hfedxgboost) +- [FedPara](https://github.com/adap/flower/tree/main/baselines/fedpara) - [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist) - [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization) -Please refer to the [Flower Baselines Documentation](https://flower.dev/docs/baselines/) for a detailed categorization of baselines and for additional info including: -* [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) -* [How to contribute a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) +Please refer to the [Flower Baselines Documentation](https://flower.ai/docs/baselines/) for a detailed categorization of baselines and for additional info including: +* [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) +* [How to contribute a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) ## Flower Usage Examples @@ -124,21 +130,31 @@ Quickstart examples: - [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai) - [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas) - [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax) +- [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai) - [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist) - [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android) - [Quickstart (iOS [CoreML])](https://github.com/adap/flower/tree/main/examples/ios) +- [Quickstart (MLX)](https://github.com/adap/flower/tree/main/examples/quickstart-mlx) +- [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) Other [examples](https://github.com/adap/flower/tree/main/examples): - [Raspberry Pi & Nvidia Jetson Tutorial](https://github.com/adap/flower/tree/main/examples/embedded-devices) - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated) +- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl) +- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning) +- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/fedllm-finetune) +- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune) - [Advanced Flower with 
TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow) - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch) -- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation_pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation_tensorflow)) +- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)) +- [Comprehensive Flower+XGBoost](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) +- [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose) +- [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplna-meier-fitter) ## Community -Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.dev/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome. +Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.ai/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome. diff --git a/baselines/README.md b/baselines/README.md index a18c0553b2b4..3a84df02d8de 100644 --- a/baselines/README.md +++ b/baselines/README.md @@ -1,7 +1,7 @@ # Flower Baselines -> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). +> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). > The documentation below has been updated to reflect the new way of using Flower baselines. @@ -23,7 +23,7 @@ Please note that some baselines might include additional files (e.g. a `requirem ## Running the baselines -Each baseline is self-contained in its own directory. Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't setup `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.dev/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so. 
+Each baseline is self-contained in its own directory. Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't setup `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.ai/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so.
 
 Assuming `pyenv` and `Poetry` are already installed on your system. Running a baseline can be done by:
 
@@ -54,7 +54,7 @@ The steps to follow are:
    ```bash
    # This will create a new directory with the same structure as `baseline_template`.
    ./dev/create-baseline.sh
-   ``` 
+   ```
 3. Then, go inside your baseline directory and continue with the steps detailed in `EXTENDED_README.md` and `README.md`.
 4. Once your code is ready and you have checked that following the instructions in your `README.md` the Python environment can be created correctly and that running the code following your instructions can reproduce the experiments in the paper, you just need to create a Pull Request (PR). Then, the process to merge your baseline into the Flower repo will begin!
diff --git a/baselines/baseline_template/pyproject.toml b/baselines/baseline_template/pyproject.toml
index da7516437f09..31f1ee7bfe6d 100644
--- a/baselines/baseline_template/pyproject.toml
+++ b/baselines/baseline_template/pyproject.toml
@@ -7,11 +7,11 @@ name = "" # <----- Ensure it matches the name of your baseline di
 version = "1.0.0"
 description = "Flower Baselines"
 license = "Apache-2.0"
-authors = ["The Flower Authors <hello@flower.dev>"]
+authors = ["The Flower Authors <hello@flower.ai>"]
 readme = "README.md"
-homepage = "https://flower.dev"
+homepage = "https://flower.ai"
 repository = "https://github.com/adap/flower"
-documentation = "https://flower.dev"
+documentation = "https://flower.ai"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
@@ -42,9 +42,9 @@ flwr = { extras = ["simulation"], version = "1.5.0" }
 hydra-core = "1.3.2" # don't change this
 
 [tool.poetry.dev-dependencies]
-isort = "==5.11.5"
-black = "==23.1.0"
-docformatter = "==1.5.1"
+isort = "==5.13.2"
+black = "==24.2.0"
+docformatter = "==1.7.5"
 mypy = "==1.4.1"
 pylint = "==2.8.2"
 flake8 = "==3.9.2"
@@ -80,10 +80,10 @@ plugins = "numpy.typing.mypy_plugin"
 [tool.pylint."MESSAGES CONTROL"]
 disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
 good-names = "i,j,k,_,x,y,X,Y"
-signature-mutators="hydra.main.main"
+signature-mutators = "hydra.main.main"
 
 [tool.pylint.typecheck]
-generated-members="numpy.*, torch.*, tensorflow.*"
+generated-members = "numpy.*, torch.*, tensorflow.*"
 
 [[tool.mypy.overrides]]
 module = [
diff --git a/baselines/dasha/pyproject.toml b/baselines/dasha/pyproject.toml
index 3ef24e4b985a..f03ad06e26e4 100644
--- a/baselines/dasha/pyproject.toml
+++ b/baselines/dasha/pyproject.toml
@@ -9,9 +9,9 @@ description = "DASHA: Distributed nonconvex optimization with communication comp
 license = "Apache-2.0"
 authors = ["Alexander Tyurin "]
 readme = "README.md"
-homepage = "https://flower.dev"
+homepage = "https://flower.ai"
 repository = "https://github.com/adap/flower"
-documentation = "https://flower.dev"
+documentation = "https://flower.ai"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
@@ -49,9 +49,9 @@ torchvision = [{ url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.
{ url = "https://download.pytorch.org/whl/cpu/torchvision-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", markers="sys_platform == 'darwin'"}] [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -60,6 +60,7 @@ pytest-watch = "==4.2.0" types-requests = "==2.27.7" py-spy = "==0.3.14" ruff = "==0.0.272" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -88,8 +89,8 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias,no-self-use,too-many-locals,too-many-instance-attributes" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" -generated-members="numpy.*, torch.*" +signature-mutators = "hydra.main.main" +generated-members = "numpy.*, torch.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/depthfl/README.md b/baselines/depthfl/README.md index b8ab7ed18571..ab26f9c8f8c9 100644 --- a/baselines/depthfl/README.md +++ b/baselines/depthfl/README.md @@ -124,7 +124,7 @@ python -m depthfl.main --multirun exclusive_learning=true model_size=1,2,3,4 **Table 2** -100% (a), 75%(b), 50%(c), 25% (d) cases are exclusive learning scenario. 100% (a) exclusive learning means, the global model and every local model are equal to the smallest local model, and 100% clients participate in learning. Likewise, 25% (d) exclusive learning means, the global model and every local model are equal to the larget local model, and only 25% clients participate in learning. +100% (a), 75%(b), 50%(c), 25% (d) cases are exclusive learning scenario. 100% (a) exclusive learning means, the global model and every local model are equal to the smallest local model, and 100% clients participate in learning. Likewise, 25% (d) exclusive learning means, the global model and every local model are equal to the largest local model, and only 25% clients participate in learning. 
| Scaling Method | Dataset | Global Model | 100% (a) | 75% (b) | 50% (c) | 25% (d) | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | diff --git a/baselines/depthfl/depthfl/models.py b/baselines/depthfl/depthfl/models.py index df3eebf9f9ce..5ce5eedb360e 100644 --- a/baselines/depthfl/depthfl/models.py +++ b/baselines/depthfl/depthfl/models.py @@ -1,6 +1,5 @@ """ResNet18 model architecutre, training, and testing functions for CIFAR100.""" - from typing import List, Tuple import torch diff --git a/baselines/depthfl/pyproject.toml b/baselines/depthfl/pyproject.toml index 2f928c2d3553..e59d16fef88e 100644 --- a/baselines/depthfl/pyproject.toml +++ b/baselines/depthfl/pyproject.toml @@ -9,9 +9,9 @@ description = "DepthFL: Depthwise Federated Learning for Heterogeneous Clients" license = "Apache-2.0" authors = ["Minjae Kim "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -37,18 +37,17 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10.0, <3.11.0" +python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this matplotlib = "3.7.1" -torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} -torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} - +torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } +torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -56,6 +55,7 @@ pytest = "==6.2.4" pytest-watch = "==4.2.0" ruff = "==0.0.272" types-requests = "==2.27.7" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -84,10 +84,10 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [tool.pylint.typecheck] -generated-members="numpy.*, torch.*, tensorflow.*" +generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/doc/source/_templates/base.html b/baselines/doc/source/_templates/base.html index 1cee99053fbe..cc171be9e6b0 100644 --- a/baselines/doc/source/_templates/base.html +++ b/baselines/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/baselines/doc/source/conf.py b/baselines/doc/source/conf.py index dabd421c61cf..dad8650cddaa 100644 --- a/baselines/doc/source/conf.py +++ b/baselines/doc/source/conf.py @@ -85,7 +85,7 @@ html_title = f"Flower Baselines {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/baselines/" +html_baseurl = "https://flower.ai/docs/baselines/" html_theme_options = { # diff --git a/baselines/doc/source/how-to-use-baselines.rst 
b/baselines/doc/source/how-to-use-baselines.rst
index ed47438ad5a9..4704a9b6074e 100644
--- a/baselines/doc/source/how-to-use-baselines.rst
+++ b/baselines/doc/source/how-to-use-baselines.rst
@@ -33,7 +33,7 @@ Setting up your machine
 
 Common to all baselines is `Poetry <https://python-poetry.org/docs/>`_, a tool to manage Python dependencies. Baselines also make use of `Pyenv <https://github.com/pyenv/pyenv>`_. You'll need to install both on your system before running a baseline. What follows is a step-by-step guide on getting :code:`pyenv` and :code:`Poetry` installed on your system.
 
-Let's begin by installing :code:`pyenv`. We'll be following the standard procedure. Please refere to the `pyenv docs <https://github.com/pyenv/pyenv>`_ for alternative ways of installing it.
+Let's begin by installing :code:`pyenv`. We'll be following the standard procedure. Please refer to the `pyenv docs <https://github.com/pyenv/pyenv>`_ for alternative ways of installing it.
 
 .. code-block:: bash
 
@@ -49,7 +49,7 @@ Let's begin by installing :code:`pyenv`. We'll be following the standard procedu
    command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"
    eval "$(pyenv init -)"
 
-Verify your installtion by opening a new terminal and
+Verify your installation by opening a new terminal and
 
 .. code-block:: bash
 
@@ -63,7 +63,7 @@ Then you can proceed and install any version of Python. Most baselines currently
 
    pyenv install 3.10.6
    # this will take a little while
-   # once done, you should see that that version is avaialble
+   # once done, you should see that that version is available
    pyenv versions
    # system
    # * 3.10.6  # <-- you just installed this
diff --git a/baselines/doc/source/index.rst b/baselines/doc/source/index.rst
index 335cfacef1ab..3a19e74b891e 100644
--- a/baselines/doc/source/index.rst
+++ b/baselines/doc/source/index.rst
@@ -1,7 +1,7 @@
 Flower Baselines Documentation
 ==============================
 
-Welcome to Flower Baselines' documentation. `Flower <https://flower.dev>`_ is a friendly federated learning framework.
+Welcome to Flower Baselines' documentation. `Flower <https://flower.ai>`_ is a friendly federated learning framework.
 
 
 Join the Flower Community
@@ -9,7 +9,7 @@ Join the Flower Community
 
 The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts.
 
-.. button-link:: https://flower.dev/join-slack
+..
button-link:: https://flower.ai/join-slack :color: primary :shadow: diff --git a/baselines/fedavgm/README.md b/baselines/fedavgm/README.md index 0953331964a7..5b8ddfcad6a8 100644 --- a/baselines/fedavgm/README.md +++ b/baselines/fedavgm/README.md @@ -104,7 +104,7 @@ poetry shell ``` ### Google Colab -If you want to setup the environemnt on Google Colab, please executed the script `conf-colab.sh`, just use the Colab terminal and the following: +If you want to setup the environment on Google Colab, please executed the script `conf-colab.sh`, just use the Colab terminal and the following: ```bash chmod +x conf-colab.sh diff --git a/baselines/fedavgm/pyproject.toml b/baselines/fedavgm/pyproject.toml index 298deafd8932..d222baa65b0e 100644 --- a/baselines/fedavgm/pyproject.toml +++ b/baselines/fedavgm/pyproject.toml @@ -3,15 +3,15 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.masonry.api" [tool.poetry] -name = "fedavgm" +name = "fedavgm" version = "1.0.0" description = "FedAvgM: Measuring the effects of non-identical data distribution for federated visual classification" license = "Apache-2.0" authors = ["Gustavo Bertoli"] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -38,18 +38,17 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9, <3.12.0" # changed! original baseline template uses >= 3.8.15 -flwr = "1.5.0" -ray = "2.6.3" +flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this cython = "^3.0.0" -tensorflow = "2.10" +tensorflow = "2.11.1" numpy = "1.25.2" matplotlib = "^3.7.2" [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -57,6 +56,7 @@ pytest = "==6.2.4" pytest-watch = "==4.2.0" ruff = "==0.0.272" types-requests = "==2.27.7" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -85,7 +85,7 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [[tool.mypy.overrides]] module = [ diff --git a/baselines/fedbn/README.md b/baselines/fedbn/README.md index 4b271bd49851..37388894ff9d 100644 --- a/baselines/fedbn/README.md +++ b/baselines/fedbn/README.md @@ -34,37 +34,37 @@ dataset: [MNIST, MNIST-M, SVHN, USPS, SynthDigits] **Model:** A six-layer CNN with 14,219,210 parameters following the structure described in appendix D.2. -**Dataset:** This baseline makes use of the pre-processed partitions created and open source by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. +**Dataset:** This baseline makes use of the pre-processed partitions created and open source by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. A more detailed explanation of the datasets is given in the following table. 
-| | MNIST | MNIST-M | SVHN | USPS | SynthDigits | -|--- |--- |--- |--- |--- |--- | -| data type| handwritten digits| MNIST modification randomly colored with colored patches| Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service | Syntehtic digits Windows TM font varying the orientation, blur and stroke colors | -| color | greyscale | RGB | RGB | greyscale | RGB | -| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 | -| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 | -| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 | -| number of testset| 10.000 | 10.000 | 26.032 | - | - | -| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) | +| | MNIST | MNIST-M | SVHN | USPS | SynthDigits | +| ------------------ | ------------------ | -------------------------------------------------------- | ------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------------------------- | +| data type | handwritten digits | MNIST modification randomly colored with colored patches | Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service | Synthetic digits Windows TM font varying the orientation, blur and stroke colors | +| color | greyscale | RGB | RGB | greyscale | RGB | +| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 | +| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 | +| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 | +| number of testset | 10.000 | 10.000 | 26.032 | - | - | +| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) | **Training Hyperparameters:** By default (i.e. if you don't override anything in the config) these main hyperparameters used are shown in the table below. For a complete list of hyperparameters, please refer to the config files in `fedbn/conf`. -| Description | Value | -| ----------- | ----- | -| rounds | 10 | -| num_clients | 5 | -| strategy_fraction_fit | 1.0 | -| strategy.fraction_evaluate | 0.0 | -| training samples per client| 743 | -| client.l_r | 10E-2 | -| local epochs | 1 | -| loss | cross entropy loss | -| optimizer | SGD | -| client_resources.num_cpu | 2 | -| client_resources.num_gpus | 0.0 | +| Description | Value | +| --------------------------- | ------------------ | +| rounds | 10 | +| num_clients | 5 | +| strategy_fraction_fit | 1.0 | +| strategy.fraction_evaluate | 0.0 | +| training samples per client | 743 | +| client.l_r | 10E-2 | +| local epochs | 1 | +| loss | cross entropy loss | +| optimizer | SGD | +| client_resources.num_cpu | 2 | +| client_resources.num_gpus | 0.0 | ## Environment Setup @@ -93,7 +93,7 @@ cd data .. ## Running the Experiments -First, activate your environment via `poetry shell`. The commands below show how to run the experiments and modify some of its key hyperparameters via the cli. Each time you run an experiment, the log and results will be stored inside `outputs//
- 🇬🇧 - 🇫🇷 + 🇬🇧 + 🇫🇷 + 🇨🇳
{% endif %} diff --git a/doc/source/_templates/sidebar/versioning.html b/doc/source/_templates/sidebar/versioning.html index dde7528d15e4..74f1cd8febb7 100644 --- a/doc/source/_templates/sidebar/versioning.html +++ b/doc/source/_templates/sidebar/versioning.html @@ -59,8 +59,8 @@ -
- +
+
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 503f76cb9eca..88cb5c05b1d8 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -86,7 +86,7 @@
 author = "The Flower Authors"
 
 # The full version, including alpha/beta/rc tags
-release = "1.7.0"
+release = "1.8.0"
 
 # -- General configuration ---------------------------------------------------
 
@@ -123,25 +123,27 @@
 # The full name is still at the top of the page
 add_module_names = False
 
+
 def find_test_modules(package_path):
     """Go through the python files and exclude every *_test.py file."""
     full_path_modules = []
     for root, dirs, files in os.walk(package_path):
         for file in files:
-            if file.endswith('_test.py'):
+            if file.endswith("_test.py"):
                 # Construct the module path relative to the package directory
                 full_path = os.path.join(root, file)
                 relative_path = os.path.relpath(full_path, package_path)
                 # Convert file path to dotted module path
-                module_path = os.path.splitext(relative_path)[0].replace(os.sep, '.')
+                module_path = os.path.splitext(relative_path)[0].replace(os.sep, ".")
                 full_path_modules.append(module_path)
     modules = []
     for full_path_module in full_path_modules:
-        parts = full_path_module.split('.')
+        parts = full_path_module.split(".")
         for i in range(len(parts)):
-            modules.append('.'.join(parts[i:]))
+            modules.append(".".join(parts[i:]))
     return modules
 
+
 # Stop from documenting the *_test.py files.
 # That's the only way to do that in autosummary (make the modules as mock_imports).
 autodoc_mock_imports = find_test_modules(os.path.abspath("../../src/py/flwr"))
@@ -173,6 +175,7 @@ def find_test_modules(package_path):
     "writing-documentation": "contributor-how-to-write-documentation.html",
     "apiref-binaries": "ref-api-cli.html",
     "fedbn-example-pytorch-from-centralized-to-federated": "example-fedbn-pytorch-from-centralized-to-federated.html",
+    "how-to-use-built-in-middleware-layers": "how-to-use-built-in-mods.html",
     # Restructuring: tutorials
     "tutorial/Flower-0-What-is-FL": "tutorial-series-what-is-federated-learning.html",
     "tutorial/Flower-1-Intro-to-FL-PyTorch": "tutorial-series-get-started-with-flower-pytorch.html",
@@ -248,7 +251,7 @@ def find_test_modules(package_path):
 html_title = f"Flower Framework"
 html_logo = "_static/flower-logo.png"
 html_favicon = "_static/favicon.ico"
-html_baseurl = "https://flower.dev/docs/framework/"
+html_baseurl = "https://flower.ai/docs/framework/"
 
 html_theme_options = {
     #
diff --git a/doc/source/contributor-explanation-architecture.rst b/doc/source/contributor-explanation-architecture.rst
index 0e2ea1f6e66b..a20a84313118 100644
--- a/doc/source/contributor-explanation-architecture.rst
+++ b/doc/source/contributor-explanation-architecture.rst
@@ -4,7 +4,7 @@ Flower Architecture
 
 Edge Client Engine
 ------------------
-`Flower <https://flower.dev>`_ core framework architecture with Edge Client Engine
+`Flower <https://flower.ai>`_ core framework architecture with Edge Client Engine
 
 .. figure:: _static/flower-architecture-ECE.png
   :width: 80 %
 
@@ -12,7 +12,7 @@ Edge Client Engine
 
 Virtual Client Engine
 ---------------------
-`Flower <https://flower.dev>`_ core framework architecture with Virtual Client Engine
+`Flower <https://flower.ai>`_ core framework architecture with Virtual Client Engine
 
 .. figure:: _static/flower-architecture-VCE.png
   :width: 80 %
 
@@ -20,7 +20,7 @@ Virtual Client Engine
 
 Virtual Client Engine and Edge Client Engine in the same workload
 -----------------------------------------------------------------
-`Flower <https://flower.dev>`_ core framework architecture with both Virtual Client Engine and Edge Client Engine
+`Flower <https://flower.ai>`_ core framework architecture with both Virtual Client Engine and Edge Client Engine
 
 .. figure:: _static/flower-architecture.drawio.png
   :width: 80 %
diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst
index d85e48155de0..5dead265bee2 100644
--- a/doc/source/contributor-how-to-build-docker-images.rst
+++ b/doc/source/contributor-how-to-build-docker-images.rst
@@ -17,7 +17,7 @@ Before we can start, we need to meet a few prerequisites in our local developmen
 
 #. Verify the Docker daemon is running. Please follow the first section on
-   `Run Flower using Docker <https://flower.dev/docs/framework/how-to-run-flower-using-docker.html>`_
+   :doc:`Run Flower using Docker <how-to-run-flower-using-docker>`
    which covers this step in more detail.
 
 Currently, Flower provides two images, a base image and a server image. There will also be a client
@@ -98,17 +98,17 @@ Building the server image
    * - ``FLWR_VERSION``
      - Version of Flower to be installed.
      - Yes
-     - ``1.6.0``
+     - ``1.7.0``
 
 The following example creates a server image with the official Flower base image py3.11-ubuntu22.04
-and Flower 1.6.0:
+and Flower 1.7.0:
 
 .. code-block:: bash
 
   $ cd src/docker/server/
  $ docker build \
     --build-arg BASE_IMAGE_TAG=py3.11-ubuntu22.04 \
-    --build-arg FLWR_VERSION=1.6.0 \
+    --build-arg FLWR_VERSION=1.7.0 \
     -t flwr_server:0.1.0 .
 
 The name of image is ``flwr_server`` and the tag ``0.1.0``. Remember that the build arguments as well
@@ -125,7 +125,7 @@ the tag of your image.
   $ docker build \
     --build-arg BASE_REPOSITORY=flwr_base \
     --build-arg BASE_IMAGE_TAG=0.1.0 \
-    --build-arg FLWR_VERSION=1.6.0 \
+    --build-arg FLWR_VERSION=1.7.0 \
     -t flwr_server:0.1.0 .
 
 After creating the image, we can test whether the image is working:
diff --git a/doc/source/contributor-how-to-contribute-translations.rst b/doc/source/contributor-how-to-contribute-translations.rst
index d97a2cb8c64f..ba59901cf1c4 100644
--- a/doc/source/contributor-how-to-contribute-translations.rst
+++ b/doc/source/contributor-how-to-contribute-translations.rst
@@ -2,13 +2,13 @@ Contribute translations
 =======================
 
 Since `Flower 1.5
-<https://flower.dev/docs/framework/ref-changelog.html#v1-5-0-2023-08-31>`_ we
+<https://flower.ai/docs/framework/ref-changelog.html#v1-5-0-2023-08-31>`_ we
 have introduced translations to our doc pages, but, as you might have noticed,
 the translations are often imperfect. If you speak languages other than
 English, you might be able to help us in our effort to make Federated Learning
 accessible to as many people as possible by contributing to those
 translations! This might also be a great opportunity for those wanting to become open source
-contributors with little prerequistes.
+contributors with little prerequisites.
 
 Our translation project is publicly available over on `Weblate `_, this where most
@@ -44,7 +44,7 @@ This is what the interface looks like:
 
 .. image:: _static/weblate_interface.png
 
-You input your translation in the textbox at the top and then, once you are
+You input your translation in the text box at the top and then, once you are
 happy with it, you either press ``Save and continue`` (to save the translation
 and go to the next untranslated string), ``Save and stay`` (to save the
 translation and stay on the same page), ``Suggest`` (to add your translation to
@@ -67,5 +67,5 @@ Add new languages
 -----------------
 
 If you want to add a new language, you will first have to contact us, either on
-`Slack <https://flower.dev/join-slack>`_, or by opening an issue on our `GitHub
+`Slack <https://flower.ai/join-slack>`_, or by opening an issue on our `GitHub
 repo <https://github.com/adap/flower>`_.
diff --git a/doc/source/contributor-how-to-create-new-messages.rst b/doc/source/contributor-how-to-create-new-messages.rst
index 24fa5f573158..3f1849bdce47 100644
--- a/doc/source/contributor-how-to-create-new-messages.rst
+++ b/doc/source/contributor-how-to-create-new-messages.rst
@@ -29,8 +29,8 @@ Let's now see what we need to implement in order to get this simple function bet
 Message Types for Protocol Buffers
 ----------------------------------
 
-The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. 
-Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. 
+The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`.
+Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_.
 
 Within the :code:`ServerMessage` block:
@@ -75,7 +75,7 @@ Once that is done, we will compile the file with:
 
    $ python -m flwr_tool.protoc
 
-If it compiles succesfully, you should see the following message:
+If it compiles successfully, you should see the following message:
 
 .. code-block:: shell
diff --git a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst
index 19d46c5753c6..c861457b6edc 100644
--- a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst
+++ b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst
@@ -8,7 +8,7 @@ When working on the Flower framework we want to ensure that all contributors use
 
 Workspace files are mounted from the local file system or copied or cloned into the container. Extensions are installed and run inside the container, where they have full access to the tools, platform, and file system. This means that you can seamlessly switch your entire development environment just by connecting to a different container.
 
-Source: `Official VSCode documentation `_
+Source: `Official VSCode documentation `_
 
 Getting started
 ---------------
@@ -20,5 +20,5 @@ Now you should be good to go. When starting VSCode, it will ask you to run in th
 
 In some cases your setup might be more involved.
For those cases consult the following sources: -* `Developing inside a Container `_ -* `Remote development in Containers `_ +* `Developing inside a Container `_ +* `Remote development in Containers `_ diff --git a/doc/source/contributor-how-to-install-development-versions.rst b/doc/source/contributor-how-to-install-development-versions.rst index 243f4ef97e8e..558ec7f8ec46 100644 --- a/doc/source/contributor-how-to-install-development-versions.rst +++ b/doc/source/contributor-how-to-install-development-versions.rst @@ -19,8 +19,8 @@ Install ``flwr`` from a local copy of the Flower source code via ``pyproject.tom Install ``flwr`` from a local wheel file via ``pyproject.toml``: -- ``flwr = { path = "../../dist/flwr-1.0.0-py3-none-any.whl" }`` (without extras) -- ``flwr = { path = "../../dist/flwr-1.0.0-py3-none-any.whl", extras = ["simulation"] }`` (with extras) +- ``flwr = { path = "../../dist/flwr-1.8.0-py3-none-any.whl" }`` (without extras) +- ``flwr = { path = "../../dist/flwr-1.8.0-py3-none-any.whl", extras = ["simulation"] }`` (with extras) Please refer to the Poetry documentation for further details: `Poetry Dependency Specification `_ @@ -59,5 +59,5 @@ Open a development version of the same notebook from branch `branch-name` by cha Install a `whl` on Google Colab: 1. In the vertical icon grid on the left hand side, select ``Files`` > ``Upload to session storage`` -2. Upload the whl (e.g., ``flwr-1.7.0-py3-none-any.whl``) -3. Change ``!pip install -q 'flwr[simulation]' torch torchvision matplotlib`` to ``!pip install -q 'flwr-1.7.0-py3-none-any.whl[simulation]' torch torchvision matplotlib`` +2. Upload the whl (e.g., ``flwr-1.8.0-py3-none-any.whl``) +3. Change ``!pip install -q 'flwr[simulation]' torch torchvision matplotlib`` to ``!pip install -q 'flwr-1.8.0-py3-none-any.whl[simulation]' torch torchvision matplotlib`` diff --git a/doc/source/contributor-how-to-release-flower.rst b/doc/source/contributor-how-to-release-flower.rst index acfac4197ec1..4853d87bc4c1 100644 --- a/doc/source/contributor-how-to-release-flower.rst +++ b/doc/source/contributor-how-to-release-flower.rst @@ -22,7 +22,7 @@ Create a pull request which contains the following changes: 2. Update all files which contain the current version number if necessary. 3. Add a new ``Unreleased`` section in ``changelog.md``. -Merge the pull request on the same day (i.e., before a new nighly release gets published to PyPI). +Merge the pull request on the same day (i.e., before a new nightly release gets published to PyPI). Publishing a pre-release ------------------------ @@ -30,11 +30,11 @@ Publishing a pre-release Pre-release naming ~~~~~~~~~~~~~~~~~~ -PyPI supports pre-releases (alpha, beta, release candiate). Pre-releases MUST use one of the following naming patterns: +PyPI supports pre-releases (alpha, beta, release candidate). Pre-releases MUST use one of the following naming patterns: - Alpha: ``MAJOR.MINOR.PATCHaN`` - Beta: ``MAJOR.MINOR.PATCHbN`` -- Release candiate (RC): ``MAJOR.MINOR.PATCHrcN`` +- Release candidate (RC): ``MAJOR.MINOR.PATCHrcN`` Examples include: diff --git a/doc/source/contributor-ref-good-first-contributions.rst b/doc/source/contributor-ref-good-first-contributions.rst index 523a4679c6ef..2b8ce88413f5 100644 --- a/doc/source/contributor-ref-good-first-contributions.rst +++ b/doc/source/contributor-ref-good-first-contributions.rst @@ -14,7 +14,7 @@ Until the Flower core library matures it will be easier to get PR's accepted if they only touch non-core areas of the codebase. 
Good candidates to get started are: -- Documentation: What's missing? What could be expressed more clearly? +- Documentation: What's missing? What could be expressed more clearly? - Baselines: See below. - Examples: See below. @@ -22,11 +22,11 @@ are: Request for Flower Baselines ---------------------------- -If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. +If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. -You should then check out the open +You should then check out the open `issues `_ for baseline requests. -If you find a baseline that you'd like to work on and that has no assignes, feel free to assign it to yourself and start working on it! +If you find a baseline that you'd like to work on and that has no assignees, feel free to assign it to yourself and start working on it! Otherwise, if you don't find a baseline you'd like to work on, be sure to open a new issue with the baseline request template! diff --git a/doc/source/contributor-tutorial-contribute-on-github.rst b/doc/source/contributor-tutorial-contribute-on-github.rst index 351d2408d9f3..6da81ce73662 100644 --- a/doc/source/contributor-tutorial-contribute-on-github.rst +++ b/doc/source/contributor-tutorial-contribute-on-github.rst @@ -3,9 +3,7 @@ Contribute on GitHub This guide is for people who want to get involved with Flower, but who are not used to contributing to GitHub projects. -If you're familiar with how contributing on GitHub works, you can directly checkout our -`getting started guide for contributors `_ -and examples of `good first contributions `_. +If you're familiar with how contributing on GitHub works, you can directly checkout our :doc:`getting started guide for contributors `. Setting up the repository @@ -13,21 +11,21 @@ Setting up the repository 1. **Create a GitHub account and setup Git** Git is a distributed version control tool. This allows for an entire codebase's history to be stored and every developer's machine. - It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. + It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. GitHub, itself, is a code hosting platform for version control and collaboration. It allows for everyone to collaborate and work from anywhere on remote repositories. - If you haven't already, you will need to create an account on `GitHub `_. + If you haven't already, you will need to create an account on `GitHub `_. - The idea behind the generic Git and GitHub workflow boils down to this: + The idea behind the generic Git and GitHub workflow boils down to this: you download code from a remote repository on GitHub, make changes locally and keep track of them using Git and then you upload your new history back to GitHub. 2. **Forking the Flower repository** - A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to https://github.com/adap/flower (while connected to your GitHub account) + A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to ``_ (while connected to your GitHub account) and click the ``Fork`` button situated on the top right of the page. .. 
image:: _static/fork_button.png - + You can change the name if you want, but this is not necessary as this version of Flower will be yours and will sit inside your own account (i.e., in your own list of repositories). Once created, you should see on the top left corner that you are looking at your own version of Flower. @@ -35,14 +33,14 @@ Setting up the repository 3. **Cloning your forked repository** The next step is to download the forked repository on your machine to be able to make changes to it. - On your forked repository page, you should first click on the ``Code`` button on the right, + On your forked repository page, you should first click on the ``Code`` button on the right, this will give you the ability to copy the HTTPS link of the repository. .. image:: _static/cloning_fork.png Once you copied the \, you can open a terminal on your machine, navigate to the place you want to download the repository to and type: - .. code-block:: shell + .. code-block:: shell $ git clone @@ -59,17 +57,17 @@ Setting up the repository To obtain it, we can do as previously mentioned by going to our fork repository on our GitHub account and copying the link. .. image:: _static/cloning_fork.png - + Once the \ is copied, we can type the following command in our terminal: .. code-block:: shell $ git remote add origin - + 5. **Add upstream** Now we will add an upstream address to our repository. - Still in the same directroy, we must run the following command: + Still in the same directory, we must run the following command: .. code-block:: shell @@ -77,10 +75,10 @@ Setting up the repository The following diagram visually explains what we did in the previous steps: - .. image:: _static/github_schema.png + .. image:: _static/github_schema.png - The upstream is the GitHub remote address of the parent repository (in this case Flower), - i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. + The upstream is the GitHub remote address of the parent repository (in this case Flower), + i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. The origin is just the GitHub remote address of the forked repository we created, i.e. the copy (fork) in our own account. To make sure our local version of the fork is up-to-date with the latest changes from the Flower repository, @@ -94,7 +92,7 @@ Setting up the repository Setting up the coding environment --------------------------------- -This can be achieved by following this `getting started guide for contributors`_ (note that you won't need to clone the repository). +This can be achieved by following this :doc:`getting started guide for contributors ` (note that you won't need to clone the repository). Once you are able to write code and test it, you can finally start making changes! @@ -114,9 +112,9 @@ And with Flower's repository: $ git pull upstream main 1. **Create a new branch** - To make the history cleaner and easier to work with, it is good practice to + To make the history cleaner and easier to work with, it is good practice to create a new branch for each feature/project that needs to be implemented. - + To do so, just run the following command inside the repository's directory: .. code-block:: shell @@ -138,7 +136,7 @@ And with Flower's repository: $ ./dev/test.sh # to test that your code can be accepted $ ./baselines/dev/format.sh # same as above but for code added to baselines $ ./baselines/dev/test.sh # same as above but for code added to baselines - + 4. 
**Stage changes** Before creating a commit that will update your history, you must specify to Git which files it needs to take into account. @@ -185,21 +183,21 @@ Creating and merging a pull request (PR) Once you click the ``Compare & pull request`` button, you should see something similar to this: .. image:: _static/creating_pr.png - + At the top you have an explanation of which branch will be merged where: .. image:: _static/merging_branch.png - + In this example you can see that the request is to merge the branch ``doc-fixes`` from my forked repository to branch ``main`` from the Flower repository. - The input box in the middle is there for you to describe what your PR does and to link it to existing issues. + The input box in the middle is there for you to describe what your PR does and to link it to existing issues. We have placed comments (that won't be rendered once the PR is opened) to guide you through the process. It is important to follow the instructions described in comments. For instance, in order to not break how our changelog system works, you should read the information above the ``Changelog entry`` section carefully. You can also checkout some examples and details in the :ref:`changelogentry` appendix. - At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and + At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and that they should look over it to merge or to request changes. If your PR is not yet ready for review, and you don't want to notify anyone, you have the option to create a draft pull request: @@ -219,7 +217,7 @@ Creating and merging a pull request (PR) Merging will be blocked if there are ongoing requested changes. .. image:: _static/changes_requested.png - + To resolve them, just push the necessary changes to the branch associated with the PR: .. image:: _static/make_changes.png @@ -257,36 +255,36 @@ Example of first contribution Problem ******* -For our documentation, we’ve started to use the `Diàtaxis framework `_. +For our documentation, we've started to use the `Diàtaxis framework `_. -Our “How to” guides should have titles that continue the sencence “How to …”, for example, “How to upgrade to Flower 1.0”. +Our "How to" guides should have titles that continue the sentence "How to …", for example, "How to upgrade to Flower 1.0". Most of our guides do not follow this new format yet, and changing their title is (unfortunately) more involved than one might think. -This issue is about changing the title of a doc from present continious to present simple. +This issue is about changing the title of a doc from present continuous to present simple. -Let's take the example of “Saving Progress” which we changed to “Save Progress”. Does this pass our check? +Let's take the example of "Saving Progress" which we changed to "Save Progress". Does this pass our check? -Before: ”How to saving progress” ❌ +Before: "How to saving progress" ❌ -After: ”How to save progress” ✅ +After: "How to save progress" ✅ Solution ******** -This is a tiny change, but it’ll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here’s what you should do: +This is a tiny change, but it'll allow us to test your end-to-end setup. 
After cloning and setting up the Flower repo, here's what you should do: - Find the source file in ``doc/source`` - Make the change in the ``.rst`` file (beware, the dashes under the title should be the same length as the title itself) -- Build the docs and check the result: ``_ +- Build the docs and `check the result `_ Rename file ::::::::::: -You might have noticed that the file name still reflects the old wording. +You might have noticed that the file name still reflects the old wording. If we just change the file, then we break all existing links to it - it is **very important** to avoid that, breaking links can harm our search engine ranking. -Here’s how to change the file name: +Here's how to change the file name: - Change the file name to ``save-progress.rst`` - Add a redirect rule to ``doc/source/conf.py`` @@ -296,7 +294,7 @@ This will cause a redirect from ``saving-progress.html`` to ``save-progress.html Apply changes in the index file ::::::::::::::::::::::::::::::: -For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well. +For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well. This is where we define the whole arborescence of the navbar. - Find and modify the file name in ``index.rst`` @@ -304,7 +302,7 @@ This is where we define the whole arborescence of the navbar. Open PR ::::::: -- Commit the changes (commit messages are always imperative: “Do something”, in this case “Change …”) +- Commit the changes (commit messages are always imperative: "Do something", in this case "Change …") - Push the changes to your fork - Open a PR (as shown above) - Wait for it to be approved! @@ -344,7 +342,7 @@ Next steps Once you have made your first PR, and want to contribute more, be sure to check out the following : -- `Good first contributions `_, where you should particularly look into the :code:`baselines` contributions. +- :doc:`Good first contributions `, where you should particularly look into the :code:`baselines` contributions. Appendix @@ -359,16 +357,13 @@ When opening a new PR, inside its description, there should be a ``Changelog ent Above this header you should see the following comment that explains how to write your changelog entry: - Inside the following 'Changelog entry' section, + Inside the following 'Changelog entry' section, you should put the description of your changes that will be added to the changelog alongside your PR title. - If the section is completely empty (without any token), - the changelog will just contain the title of the PR for the changelog entry, without any description. - - If the 'Changelog entry' section is removed entirely, - it will categorize the PR as "General improvement" and add it to the changelog accordingly. + If the section is completely empty (without any token) or non-existent, + the changelog will just contain the title of the PR for the changelog entry, without any description. - If the section contains some text other than tokens, it will use it to add a description to the change. + If the section contains some text other than tokens, it will use it to add a description to the change. If the section contains one of the following tokens it will ignore any other text and put the PR under the corresponding section of the changelog: @@ -388,11 +383,7 @@ Above this header you should see the following comment that explains how to writ Its content must have a specific format. 
We will break down what each possibility does: -- If the ``### Changelog entry`` section is removed, the following text will be added to the changelog:: - - - **General improvements** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) - -- If the ``### Changelog entry`` section contains nothing but exists, the following text will be added to the changelog:: +- If the ``### Changelog entry`` section contains nothing or doesn't exist, the following text will be added to the changelog:: - **PR TITLE** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index 72c6df5fdbc7..9136fea96bf6 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -11,7 +11,7 @@ Prerequisites Flower uses :code:`pyproject.toml` to manage dependencies and configure development tools (the ones which support it). Poetry is a build tool which -supports `PEP 517 `_. +supports `PEP 517 `_. Developer Machine Setup @@ -27,7 +27,7 @@ For macOS * Install `homebrew `_. Don't forget the post-installation actions to add `brew` to your PATH. * Install `xz` (to install different Python versions) and `pandoc` to build the docs:: - + $ brew install xz pandoc For Ubuntu @@ -54,7 +54,7 @@ GitHub:: * If you don't have :code:`pyenv` installed, the following script that will install it, set it up, and create the virtual environment (with :code:`Python 3.8.17` by default):: $ ./dev/setup-defaults.sh # once completed, run the bootstrap script - + * If you already have :code:`pyenv` installed (along with the :code:`pyenv-virtualenv` plugin), you can use the following convenience script (with :code:`Python 3.8.17` by default):: $ ./dev/venv-create.sh # once completed, run the `bootstrap.sh` script @@ -70,7 +70,7 @@ Convenience Scripts The Flower repository contains a number of convenience scripts to make recurring development tasks easier and less error-prone. See the :code:`/dev` -subdirectory for a full list. The following scripts are amonst the most +subdirectory for a full list. The following scripts are amongst the most important ones: Create/Delete Virtual Environment diff --git a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst index 5ebaa337dde8..0139f3b8dc31 100644 --- a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst +++ b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst @@ -3,11 +3,11 @@ Example: FedBN in PyTorch - From Centralized To Federated This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with `FedBN `_, a federated training strategy designed for non-iid data. We are using PyTorch to train a Convolutional Neural Network(with Batch Normalization layers) on the CIFAR-10 dataset. -When applying FedBN, only few changes needed compared to `Example: PyTorch - From Centralized To Federated `_. +When applying FedBN, only few changes needed compared to :doc:`Example: PyTorch - From Centralized To Federated `. Centralized Training -------------------- -All files are revised based on `Example: PyTorch - From Centralized To Federated `_. +All files are revised based on :doc:`Example: PyTorch - From Centralized To Federated `. 
The only thing to do is modify the file called :code:`cifar.py`; the revised part is shown below:
The model architecture defined in class Net() is extended with Batch Normalization layers accordingly.

@@ -45,13 +45,13 @@ You can now run your machine learning workload:

 python3 cifar.py

 So far this should all look fairly familiar if you've used PyTorch before.
-Let's take the next step and use what we've built to create a federated learning system within FedBN, the sytstem consists of one server and two clients.
+Let's take the next step and use what we've built to create a federated learning system with FedBN; the system consists of one server and two clients.

 Federated Training
 ------------------

-If you have read `Example: PyTorch - From Centralized To Federated `_, the following parts are easy to follow, onyl :code:`get_parameters` and :code:`set_parameters` function in :code:`client.py` needed to revise.
-If not, please read the `Example: PyTorch - From Centralized To Federated `_. first.
+If you have read :doc:`Example: PyTorch - From Centralized To Federated `, the following parts are easy to follow; only the :code:`get_parameters` and :code:`set_parameters` functions in :code:`client.py` need to be revised.
+If not, please read :doc:`Example: PyTorch - From Centralized To Federated ` first.

 Our example consists of one *server* and two *clients*. In FedBN, :code:`server.py` stays unchanged, so we can start the server directly.

@@ -66,7 +66,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an

     class CifarClient(fl.client.NumPyClient):
         """Flower client implementing CIFAR-10 image classification using PyTorch."""
-
+
         ...

         def get_parameters(self, config) -> List[np.ndarray]:
@@ -79,7 +79,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an
             params_dict = zip(keys, parameters)
             state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
             self.model.load_state_dict(state_dict, strict=False)
-
+
         ...

 Now, you can open two additional terminal windows and run

diff --git a/doc/source/example-jax-from-centralized-to-federated.rst b/doc/source/example-jax-from-centralized-to-federated.rst
index 2b1823e9d408..6b06a288a67a 100644
--- a/doc/source/example-jax-from-centralized-to-federated.rst
+++ b/doc/source/example-jax-from-centralized-to-federated.rst
@@ -259,7 +259,7 @@ Having defined the federation process, we can run it.

     # Start Flower client
     client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
-    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client)
+    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

 if __name__ == "__main__":
     main()
diff --git a/doc/source/example-mxnet-walk-through.rst b/doc/source/example-mxnet-walk-through.rst
index a395a8a723cb..c215f709ffb2 100644
--- a/doc/source/example-mxnet-walk-through.rst
+++ b/doc/source/example-mxnet-walk-through.rst
@@ -234,7 +234,7 @@ Our *client* needs to import :code:`flwr`, but also :code:`mxnet` to update the

 Implementing a Flower *client* basically means implementing a subclass of either :code:`flwr.client.Client` or :code:`flwr.client.NumPyClient`. Our implementation will be based on :code:`flwr.client.NumPyClient` and we'll call it :code:`MNISTClient`.
-:code:`NumPyClient` is slighly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary.
+:code:`NumPyClient` is slightly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary.
 :code:`MNISTClient` needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

 #. :code:`set_parameters (optional)`
diff --git a/doc/source/example-pytorch-from-centralized-to-federated.rst b/doc/source/example-pytorch-from-centralized-to-federated.rst
index d649658667da..0c458a136a81 100644
--- a/doc/source/example-pytorch-from-centralized-to-federated.rst
+++ b/doc/source/example-pytorch-from-centralized-to-federated.rst
@@ -174,7 +174,7 @@ However, with Flower you can evolve your pre-existing code into a federated lear
 The concept is easy to understand. We have to start a *server* and then use the code in :code:`cifar.py` for the *clients* that are connected to the *server*.

-The *server* sends model parameters to the clients. The *clients* run the training and update the paramters.
+The *server* sends model parameters to the clients. The *clients* run the training and update the parameters.
 The updated parameters are sent back to the *server* which averages all received parameter updates.
 This describes one round of the federated learning process and we repeat this for multiple rounds.

@@ -195,7 +195,7 @@ We can already start the *server*:

     python3 server.py

 Finally, we will define our *client* logic in :code:`client.py` and build upon the previously defined centralized training in :code:`cifar.py`.
-Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the paramters on our PyTorch model:
+Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the parameters on our PyTorch model:

 .. code-block:: python

@@ -212,7 +212,7 @@ Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the

 Implementing a Flower *client* basically means implementing a subclass of either :code:`flwr.client.Client` or :code:`flwr.client.NumPyClient`. Our implementation will be based on :code:`flwr.client.NumPyClient` and we'll call it :code:`CifarClient`.
-:code:`NumPyClient` is slighly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary.
+:code:`NumPyClient` is slightly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary.
 :code:`CifarClient` needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

 #. :code:`set_parameters`
@@ -278,7 +278,7 @@ We included type annotations to give you a better understanding of the data type
         return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

 All that's left to do is to define a function that loads both model and data, creates a :code:`CifarClient`, and starts this client.
-You load your data and model by using :code:`cifar.py`.
Start :code:`CifarClient` with the function :code:`fl.client.start_client()` by pointing it at the same IP address we used in :code:`server.py`:

 .. code-block:: python

@@ -292,7 +292,7 @@ You load your data and model by using :code:`cifar.py`.

     # Start client
     client = CifarClient(model, trainloader, testloader, num_examples)
-    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client)
+    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

 if __name__ == "__main__":
diff --git a/doc/source/example-walkthrough-pytorch-mnist.rst b/doc/source/example-walkthrough-pytorch-mnist.rst
index ab311813f5de..f8eacc8647fe 100644
--- a/doc/source/example-walkthrough-pytorch-mnist.rst
+++ b/doc/source/example-walkthrough-pytorch-mnist.rst
@@ -1,11 +1,11 @@
 Example: Walk-Through PyTorch & MNIST
 =====================================

-In this tutorial we will learn, how to train a Convolutional Neural Network on MNIST using Flower and PyTorch. 
+In this tutorial we will learn how to train a Convolutional Neural Network on MNIST using Flower and PyTorch.

-Our example consists of one *server* and two *clients* all having the same model. 
+Our example consists of one *server* and two *clients*, all having the same model.

-*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. 
+*Clients* are responsible for generating individual weight-updates for the model based on their local datasets.
 These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*.
 A complete cycle of weight updates is called a *round*.

@@ -15,7 +15,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n

 $ pip install flwr

-Since we want to use PyTorch to solve a computer vision task, let's go ahead an install PyTorch and the **torchvision** library: 
+Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library:

.. code-block:: shell

@@ -32,51 +32,51 @@ Go ahead and launch on a terminal the *run-server.sh* script first as follows:

.. code-block:: shell

-    $ bash ./run-server.sh 
+    $ bash ./run-server.sh

-Now that the server is up and running, go ahead and launch the clients. 
+Now that the server is up and running, go ahead and launch the clients.

.. code-block:: shell

-    $ bash ./run-clients.sh 
+    $ bash ./run-clients.sh

 Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client.

..
code-block:: shell - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 - - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 - - Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% - + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 + + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 + + Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% + Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14% -Now, let's see what is really happening inside. +Now, let's see what is really happening inside. Flower Server ------------- Inside the server helper script *run-server.sh* you will find the following code that basically runs the :code:`server.py` -.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.server We can go a bit deeper and see that :code:`server.py` simply launches a server that will coordinate three rounds of training. -Flower Servers are very customizable, but for simple workloads, we can start a server using the :ref:`start_server ` function and leave all the configuration possibilities at their default values, as seen below. +Flower Servers are very customizable, but for simple workloads, we can start a server using the `start_server `_ function and leave all the configuration possibilities at their default values, as seen below. .. code-block:: python @@ -90,18 +90,18 @@ Flower Client Next, let's take a look at the *run-clients.sh* file. You will see that it contains the main loop that starts a set of *clients*. -.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.client \ --cid=$i \ --server_address=$SERVER_ADDRESS \ - --nb_clients=$NUM_CLIENTS + --nb_clients=$NUM_CLIENTS * **cid**: is the client ID. It is an integer that uniquely identifies client identifier. -* **sever_address**: String that identifies IP and port of the server. +* **sever_address**: String that identifies IP and port of the server. * **nb_clients**: This defines the number of clients being created. This piece of information is not required by the client, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both *training* and *test* sets. -Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. +Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. After going through the argument parsing code at the beginning of our :code:`main` function, you will find a call to :code:`mnist.load_data`. This function is responsible for partitioning the original MNIST datasets (*training* and *test*) and returning a :code:`torch.utils.data.DataLoader` s for each of them. We then instantiate a :code:`PytorchMNISTClient` object with our client ID, our DataLoaders, the number of epochs in each round, and which device we want to use for training (CPU or GPU). 
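Putting these pieces together, the instantiation and start-up step described above might look roughly like the sketch below. The exact signatures of :code:`mnist.load_data` and :code:`PytorchMNISTClient` are not reproduced in this excerpt, so the import paths and keyword names here are illustrative assumptions, not the repository's exact API:

.. code-block:: python

    import torch
    import flwr as fl

    # Hypothetical import paths -- the real module layout may differ.
    from flwr_example.quickstart_pytorch import mnist
    from flwr_example.quickstart_pytorch.client import PytorchMNISTClient

    # Partition MNIST so each client works on a unique subset (assumed signature)
    train_loader, test_loader = mnist.load_data(
        data_root="./data", cid=0, nb_clients=2, train_batch_size=64, test_batch_size=64
    )

    client = PytorchMNISTClient(
        cid=0,                       # unique client ID
        train_loader=train_loader,
        test_loader=test_loader,
        epochs=10,                   # local epochs per round
        device=torch.device("cpu"),  # or torch.device("cuda")
    )
    fl.client.start_client(server_address="0.0.0.0:8080", client=client)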
@@ -152,7 +152,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - weights: fl.common.NDArrays + weights: fl.common.NDArrays Weights received by the server and set to local model @@ -179,8 +179,8 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.FitIns - Parameters sent by the server to be used during training. + ins: fl.common.FitIns + Parameters sent by the server to be used during training. Returns ------- @@ -214,9 +214,9 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.EvaluateIns - Parameters sent by the server to be used during testing. - + ins: fl.common.EvaluateIns + Parameters sent by the server to be used during testing. + Returns ------- @@ -262,9 +262,9 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it Parameters ---------- - x: Tensor + x: Tensor Mini-batch of shape (N,28,28) containing images from MNIST dataset. - + Returns ------- @@ -287,7 +287,7 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it return output -The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: +The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: .. code-block:: python @@ -312,7 +312,7 @@ The second thing to notice is that :code:`PytorchMNISTClient` class inherits fro """Evaluate the provided weights using the locally held dataset.""" -When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. +When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. These functions can both be found inside the same :code:`quickstart-pytorch.mnist` module: @@ -330,21 +330,21 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList Neural network model used in this example. - + train_loader: torch.utils.data.DataLoader - DataLoader used in traning. - - epochs: int - Number of epochs to run in each round. - - device: torch.device + DataLoader used in training. + + epochs: int + Number of epochs to run in each round. + + device: torch.device (Default value = torch.device("cpu")) Device where the network will be trained within a client. Returns ------- num_examples_train: int - Number of total samples used during traning. + Number of total samples used during training. """ model.train() @@ -399,10 +399,10 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList : Neural network model used in this example. - + test_loader: torch.utils.data.DataLoader : DataLoader used in test. - + device: torch.device : (Default value = torch.device("cpu")) Device where the network will be tested within a client. 
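The docstrings above describe ordinary PyTorch loops. As a point of reference, a substitute :code:`test` loop with the same contract (returning final statistics for the round) could look like this minimal sketch; it is an illustration, not the repository's exact code:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def test(model, test_loader, device=torch.device("cpu")):
        """Evaluate `model` and return (average loss, accuracy)."""
        model.eval()
        total_loss, correct, num_examples = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                total_loss += F.cross_entropy(outputs, labels, reduction="sum").item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                num_examples += labels.size(0)
        return total_loss / num_examples, correct / num_examples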
@@ -435,19 +435,19 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis Observe that these functions encapsulate regular training and test loops and provide :code:`fit` and :code:`evaluate` with final statistics for each round. -You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. -As a matter of fact, why not try and modify the code to an example of your liking? +You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. +As a matter of fact, why not try and modify the code to an example of your liking? Give It a Try ------------- -Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. +Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. Here are a few things you could try on your own and get more experience with Flower: - Try and change :code:`PytorchMNISTClient` so it can accept different architectures. - Modify the :code:`train` function so that it accepts different optimizers - Modify the :code:`test` function so that it proves not only the top-1 (regular accuracy) but also the top-5 accuracy? -- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? +- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? You are ready now. Enjoy learning in a federated way! diff --git a/doc/source/explanation-differential-privacy.rst b/doc/source/explanation-differential-privacy.rst index fcebd81983cc..69fd333f9b13 100644 --- a/doc/source/explanation-differential-privacy.rst +++ b/doc/source/explanation-differential-privacy.rst @@ -1,100 +1,139 @@ -Differential privacy +Differential Privacy ==================== +The information in datasets like healthcare, financial transactions, user preferences, etc., is valuable and has the potential for scientific breakthroughs and provides important business insights. +However, such data is also sensitive and there is a risk of compromising individual privacy. -Flower provides differential privacy (DP) wrapper classes for the easy integration of the central DP guarantees provided by DP-FedAvg into training pipelines defined in any of the various ML frameworks that Flower is compatible with. +Traditional methods like anonymization alone would not work because of attacks like Re-identification and Data Linkage. +That's where differential privacy comes in. It provides the possibility of analyzing data while ensuring the privacy of individuals. -.. warning:: - Please note that these components are still experimental; the correct configuration of DP for a specific task is still an unsolved problem. -.. note:: - The name DP-FedAvg is misleading since it can be applied on top of any FL algorithm that conforms to the general structure prescribed by the FedOpt family of algorithms. +Differential Privacy +-------------------- +Imagine two datasets that are identical except for a single record (for instance, Alice's data). 
+Differential Privacy (DP) guarantees that any analysis (M), like calculating the average income, will produce nearly identical results for both datasets (O and O' would be similar).
+This preserves group patterns while obscuring individual details, ensuring the individual's information remains hidden in the crowd.

-DP-FedAvg
---------

-DP-FedAvg, originally proposed by McMahan et al. [mcmahan]_ and extended by Andrew et al. [andrew]_, is essentially FedAvg with the following modifications.

+.. image:: ./_static/DP/dp-intro.png
+ :align: center
+ :width: 400
+ :alt: DP Intro

-* **Clipping** : The influence of each client's update is bounded by clipping it. This is achieved by enforcing a cap on the L2 norm of the update, scaling it down if needed.
-* **Noising** : Gaussian noise, calibrated to the clipping threshold, is added to the average computed at the server.

+One of the most commonly used mechanisms to achieve DP is adding enough noise to the output of the analysis to mask the contribution of each individual in the data while preserving the overall accuracy of the analysis.

-The distribution of the update norm has been shown to vary from task-to-task and to evolve as training progresses. This variability is crucial in understanding its impact on differential privacy guarantees, emphasizing the need for an adaptive approach [andrew]_ that continuously adjusts the clipping threshold to track a prespecified quantile of the update norm distribution.

+Formal Definition
+~~~~~~~~~~~~~~~~~
+Differential Privacy (DP) provides statistical guarantees against the information an adversary can infer through the output of a randomized algorithm.
+It provides an unconditional upper bound on the influence of a single individual on the output of the algorithm by adding noise [1].
+A randomized mechanism M provides (:math:`\epsilon`, :math:`\delta`)-differential privacy if for any two neighboring databases, D :sub:`1` and D :sub:`2`, that differ in only a single record,
+and for all possible outputs S ⊆ Range(M):

-Simplifying Assumptions
-***********************

.. math::

-We make (and attempt to enforce) a number of assumptions that must be satisfied to ensure that the training process actually realizes the :math:`(\epsilon, \delta)` guarantees the user has in mind when configuring the setup.

+    \small
+    P[M(D_{1}) \in S] \leq e^{\epsilon} P[M(D_{2}) \in S] + \delta

-* **Fixed-size subsampling** :Fixed-size subsamples of the clients must be taken at each round, as opposed to variable-sized Poisson subsamples.
-* **Unweighted averaging** : The contributions from all the clients must weighted equally in the aggregate to eliminate the requirement for the server to know in advance the sum of the weights of all clients available for selection.
-* **No client failures** : The set of available clients must stay constant across all rounds of training. In other words, clients cannot drop out or fail.
-The first two are useful for eliminating a multitude of complications associated with calibrating the noise to the clipping threshold, while the third one is required to comply with the assumptions of the privacy analysis.

+The :math:`\epsilon` parameter, also known as the privacy budget, is a metric of privacy loss.
+It also controls the privacy-utility trade-off; lower :math:`\epsilon` values indicate higher levels of privacy but are likely to reduce utility as well.
+The :math:`\delta` parameter accounts for a small probability on which the upper bound :math:`\epsilon` does not hold.
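To make the definition concrete, the following minimal sketch (not part of Flower) releases a value with Gaussian noise calibrated by the classic rule :math:`\sigma = \Delta \sqrt{2 \log(1.25/\delta)} / \epsilon`, the same formula used for local DP later on this page; the function name and defaults are illustrative:

.. code-block:: python

    import numpy as np

    def gaussian_mechanism(value, sensitivity, epsilon, delta, rng=None):
        """Release `value` with (epsilon, delta)-DP via calibrated Gaussian noise."""
        if rng is None:
            rng = np.random.default_rng()
        # Classic calibration, valid for epsilon <= 1
        sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
        return np.asarray(value, dtype=float) + rng.normal(0.0, sigma, np.shape(value))

    # Example: privately release an average income with sensitivity 1.0
    noisy_average = gaussian_mechanism(52_340.0, sensitivity=1.0, epsilon=0.5, delta=1e-5)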
+The amount of noise needed to achieve differential privacy is proportional to the sensitivity of the output, which measures the maximum change in the output due to the inclusion or removal of a single record.

-.. note::
- These restrictions are in line with constraints imposed by Andrew et al. [andrew]_.

-Customizable Responsibility for Noise injection
-***********************************************
-In contrast to other implementations where the addition of noise is performed at the server, you can configure the site of noise injection to better match your threat model. We provide users with the flexibility to set up the training such that each client independently adds a small amount of noise to the clipped update, with the result that simply aggregating the noisy updates is equivalent to the explicit addition of noise to the non-noisy aggregate at the server.

+Differential Privacy in Machine Learning
+----------------------------------------
+DP can be utilized in machine learning to preserve the privacy of the training data.
+Differentially private machine learning algorithms are designed to prevent the algorithm from learning any specific information about individual data points and, in turn, to prevent the model from revealing sensitive information.
+Depending on the stage at which noise is introduced, various methods exist for applying DP to machine learning algorithms.
+One approach involves adding noise to the training data (either to the features or labels), while another method entails injecting noise into the gradients of the loss function during model training.
+Additionally, such noise can be incorporated into the model's output.

+Differential Privacy in Federated Learning
+------------------------------------------
+Federated learning is a data minimization approach that allows multiple parties to collaboratively train a model without sharing their raw data.
+However, federated learning also introduces new privacy challenges. The model updates exchanged between parties and the central server can leak information about the local data.
+These leaks can be exploited by attacks such as membership inference, property inference, or model inversion attacks.

-To be precise, if we let :math:`m` be the number of clients sampled each round and :math:`\sigma_\Delta` be the scale of the total Gaussian noise that needs to be added to the sum of the model updates, we can use simple maths to show that this is equivalent to each client adding noise with scale :math:`\sigma_\Delta/\sqrt{m}`.

+DP can play a crucial role in federated learning to provide privacy for the clients' data.

-Wrapper-based approach
----------------------

+Depending on the granularity of privacy provision or the location of noise addition, different forms of DP exist in federated learning.
+In this explainer, we focus on two approaches to using DP in federated learning, based on where the noise is added: at the server (also known as central DP) or at the client (also known as local DP).

-Introducing DP to an existing workload can be thought of as adding an extra layer of security around it. This inspired us to provide the additional server and client-side logic needed to make the training process differentially private as wrappers for instances of the :code:`Strategy` and :code:`NumPyClient` abstract classes respectively.
This wrapper-based approach has the advantage of being easily composable with other wrappers that someone might contribute to the Flower library in the future, e.g., for secure aggregation. Using Inheritance instead can be tedious because that would require the creation of new sub- classes every time a new class implementing :code:`Strategy` or :code:`NumPyClient` is defined. +- **Central Differential Privacy**: DP is applied by the server and the goal is to prevent the aggregated model from leaking information about each client's data. -Server-side logic -***************** +- **Local Differential Privacy**: DP is applied on the client side before sending any information to the server and the goal is to prevent the updates that are sent to the server from leaking any information about the client's data. -The first version of our solution was to define a decorator whose constructor accepted, among other things, a boolean-valued variable indicating whether adaptive clipping was to be enabled or not. We quickly realized that this would clutter its :code:`__init__()` function with variables corresponding to hyperparameters of adaptive clipping that would remain unused when it was disabled. A cleaner implementation could be achieved by splitting the functionality into two decorators, :code:`DPFedAvgFixed` and :code:`DPFedAvgAdaptive`, with the latter sub- classing the former. The constructors for both classes accept a boolean parameter :code:`server_side_noising`, which, as the name suggests, determines where noising is to be performed. +Central Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In this approach, which is also known as user-level DP, the central server is responsible for adding noise to the globally aggregated parameters. It should be noted that trust in the server is required. -DPFedAvgFixed -::::::::::::: +.. image:: ./_static/DP/CDP.png + :align: center + :width: 400 + :alt: Central Differential Privacy -The server-side capabilities required for the original version of DP-FedAvg, i.e., the one which performed fixed clipping, can be completely captured with the help of wrapper logic for just the following two methods of the :code:`Strategy` abstract class. +While there are various ways to implement central DP in federated learning, we concentrate on the algorithms proposed by [2] and [3]. +The overall approach is to clip the model updates sent by the clients and add some amount of noise to the aggregated model. +In each iteration, a random set of clients is chosen with a specific probability for training. +Each client performs local training on its own data. +The update of each client is then clipped by some value `S` (sensitivity `S`). +This would limit the impact of any individual client which is crucial for privacy and often beneficial for robustness. +A common approach to achieve this is by restricting the `L2` norm of the clients' model updates, ensuring that larger updates are scaled down to fit within the norm `S`. -#. :code:`configure_fit()` : The config dictionary being sent by the wrapped :code:`Strategy` to each client needs to be augmented with an additional value equal to the clipping threshold (keyed under :code:`dpfedavg_clip_norm`) and, if :code:`server_side_noising=true`, another one equal to the scale of the Gaussian noise that needs to be added at the client (keyed under :code:`dpfedavg_noise_stddev`). This entails *post*-processing of the results returned by the wrappee's implementation of :code:`configure_fit()`. -#. 
:code:`aggregate_fit()`: We check whether any of the sampled clients dropped out or failed to upload an update before the round timed out. In that case, we need to abort the current round, discarding any successful updates that were received, and move on to the next one. On the other hand, if all clients responded successfully, we must force the averaging of the updates to happen in an unweighted manner by intercepting the :code:`parameters` field of :code:`FitRes` for each received update and setting it to 1. Furthermore, if :code:`server_side_noising=true`, each update is perturbed with an amount of noise equal to what it would have been subjected to had client-side noising being enabled. This entails *pre*-processing of the arguments to this method before passing them on to the wrappee's implementation of :code:`aggregate_fit()`. +.. image:: ./_static/DP/clipping.png + :align: center + :width: 300 + :alt: clipping -.. note:: - We can't directly change the aggregation function of the wrapped strategy to force it to add noise to the aggregate, hence we simulate client-side noising to implement server-side noising. +Afterwards, the Gaussian mechanism is used to add noise in order to distort the sum of all clients' updates. +The amount of noise is scaled to the sensitivity value to obtain a privacy guarantee. +The Gaussian mechanism is used with a noise sampled from `N(0, σ²)` where `σ = (noise_scale * S) / (number of sampled clients)`. -These changes have been put together into a class called :code:`DPFedAvgFixed`, whose constructor accepts the strategy being decorated, the clipping threshold and the number of clients sampled every round as compulsory arguments. The user is expected to specify the clipping threshold since the order of magnitude of the update norms is highly dependent on the model being trained and providing a default value would be misleading. The number of clients sampled at every round is required to calculate the amount of noise that must be added to each individual update, either by the server or the clients. +Clipping +^^^^^^^^ -DPFedAvgAdaptive -:::::::::::::::: +There are two forms of clipping commonly used in Central DP: Fixed Clipping and Adaptive Clipping. -The additional functionality required to facilitate adaptive clipping has been provided in :code:`DPFedAvgAdaptive`, a subclass of :code:`DPFedAvgFixed`. It overrides the above-mentioned methods to do the following. +- **Fixed Clipping**: A predefined fixed threshold is set for the magnitude of clients' updates. Any update exceeding this threshold is clipped back to the threshold value (see the sketch below). -#. :code:`configure_fit()` : It intercepts the config dict returned by :code:`super.configure_fit()` to add the key-value pair :code:`dpfedavg_adaptive_clip_enabled:True` to it, which the client interprets as an instruction to include an indicator bit (1 if update norm <= clipping threshold, 0 otherwise) in the results returned by it. +- **Adaptive Clipping**: The clipping threshold dynamically adjusts based on the observed update distribution [4]. This means the clipping value is tuned over the rounds to track a quantile of the update norm distribution. -#. :code:`aggregate_fit()` : It follows a call to :code:`super.aggregate_fit()` with one to :code:`__update_clip_norm__()`, a procedure which adjusts the clipping threshold on the basis of the indicator bits received from the sampled clients.
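+To make the procedure above concrete, the following is a minimal NumPy sketch of one central-DP aggregation round with fixed clipping (illustrative only; these helper functions are not Flower APIs):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def clip_update(update, clipping_norm):
+        """Scale an update down so that its L2 norm is at most `clipping_norm`."""
+        norm = np.linalg.norm(update)
+        if norm > clipping_norm:
+            update = update * (clipping_norm / norm)
+        return update
+
+    def aggregate_with_central_dp(updates, clipping_norm, noise_scale):
+        """Clip each client update, average, then add calibrated Gaussian noise."""
+        clipped = [clip_update(u, clipping_norm) for u in updates]
+        aggregate = np.mean(clipped, axis=0)
+        sigma = noise_scale * clipping_norm / len(updates)
+        return aggregate + np.random.normal(0.0, sigma, size=aggregate.shape)
+
+In Flower itself, this logic lives behind the strategy wrappers described in the accompanying how-to guide; the sketch only illustrates the arithmetic.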
+The choice between fixed and adaptive clipping depends on various factors such as privacy requirements, data distribution, model complexity, and others. +Local Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~ -Client-side logic -***************** +In this approach, each client is responsible for performing DP. +Local DP avoids the need for a fully trusted aggregator; however, it typically trades some accuracy for stronger privacy in comparison to central DP. -The client-side capabilities required can be completely captured through wrapper logic for just the :code:`fit()` method of the :code:`NumPyClient` abstract class. To be precise, we need to *post-process* the update computed by the wrapped client to clip it, if necessary, to the threshold value supplied by the server as part of the config dictionary. In addition to this, it may need to perform some extra work if either (or both) of the following keys are also present in the dict. +.. image:: ./_static/DP/LDP.png + :align: center + :width: 400 + :alt: Local Differential Privacy -* :code:`dpfedavg_noise_stddev` : Generate and add the specified amount of noise to the clipped update. -* :code:`dpfedavg_adaptive_clip_enabled` : Augment the metrics dict in the :code:`FitRes` object being returned to the server with an indicator bit, calculated as described earlier. +In this explainer, we focus on two forms of achieving Local DP: -Performing the :math:`(\epsilon, \delta)` analysis -------------------------------------------------- +- Each client adds noise to the local updates before sending them to the server. To achieve (:math:`\epsilon`, :math:`\delta`)-DP, with the sensitivity of the local model denoted by :math:`\Delta`, Gaussian noise is applied with a noise scale of :math:`\sigma` where: -Assume you have trained for :math:`n` rounds with sampling fraction :math:`q` and noise multiplier :math:`z`. In order to calculate the :math:`\epsilon` value this would result in for a particular :math:`\delta`, the following script may be used. +.. math:: + \small + \sigma = \frac{\Delta \times \sqrt{2 \times \log\left(\frac{1.25}{\delta}\right)}}{\epsilon} -.. code-block:: python - import tensorflow_privacy as tfp - max_order = 32 - orders = range(2, max_order + 1) - rdp = tfp.compute_rdp_sample_without_replacement(q, z, n, orders) - eps, _, _ = tfp.rdp_accountant.get_privacy_spent(rdp, target_delta=delta) +- Each client adds noise to the gradients of the model during the local training (DP-SGD). More specifically, in this approach, gradients are clipped and an amount of calibrated noise is injected into the gradients. -.. [mcmahan] McMahan et al. "Learning Differentially Private Recurrent Language Models." International Conference on Learning Representations (ICLR), 2017. -.. [andrew] Andrew, Galen, et al. "Differentially Private Learning with Adaptive Clipping." Advances in Neural Information Processing Systems (NeurIPS), 2021. +Please note that these two approaches provide privacy at different levels. + + +**References:** + +[1] Dwork et al. The Algorithmic Foundations of Differential Privacy. + +[2] McMahan et al. Learning Differentially Private Recurrent Language Models. + +[3] Geyer et al. Differentially Private Federated Learning: A Client Level Perspective. + +[4] Andrew et al. Differentially Private Learning with Adaptive Clipping.
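+As an illustration of the noise calibration above, here is a minimal NumPy sketch of the local-DP Gaussian mechanism (illustrative only; the function name is not a Flower API):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def local_dp_noise_scale(sensitivity, epsilon, delta):
+        """Noise scale for (epsilon, delta)-DP via the Gaussian mechanism."""
+        return sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
+
+    update = np.array([0.2, -0.1, 0.4])  # a client's clipped model update
+    sigma = local_dp_noise_scale(sensitivity=1.0, epsilon=1.0, delta=1e-5)
+    noisy_update = update + np.random.normal(0.0, sigma, size=update.shape)
+
+For example, with :math:`\Delta = 1`, :math:`\epsilon = 1`, and :math:`\delta = 10^{-5}`, this yields :math:`\sigma \approx 4.84`.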
diff --git a/doc/source/explanation-federated-evaluation.rst b/doc/source/explanation-federated-evaluation.rst index 632241f70d36..bcdca9bae700 100644 --- a/doc/source/explanation-federated-evaluation.rst +++ b/doc/source/explanation-federated-evaluation.rst @@ -120,7 +120,7 @@ Federated evaluation can be configured from the server side. Built-in strategies # Create strategy strategy = fl.server.strategy.FedAvg( - # ... other FedAvg agruments + # ... other FedAvg arguments fraction_evaluate=0.2, min_evaluate_clients=2, min_available_clients=10, diff --git a/doc/source/how-to-configure-clients.rst b/doc/source/how-to-configure-clients.rst index 26c132125ccf..ff0a2f4033df 100644 --- a/doc/source/how-to-configure-clients.rst +++ b/doc/source/how-to-configure-clients.rst @@ -13,7 +13,7 @@ Configuration values are represented as a dictionary with ``str`` keys and value config_dict = { "dropout": True, # str key, bool value "learning_rate": 0.01, # str key, float value - "batch_size": 32, # str key, int value + "batch_size": 32, # str key, int value "optimizer": "sgd", # str key, str value } @@ -56,7 +56,7 @@ To make the built-in strategies use this function, we can pass it to ``FedAvg`` One the client side, we receive the configuration dictionary in ``fit``: .. code-block:: python class FlowerClient(flwr.client.NumPyClient): def fit(parameters, config): print(config["batch_size"]) # Prints `32` @@ -86,7 +86,7 @@ Configuring individual clients In some cases, it is necessary to send different configuration values to different clients. -This can be achieved by customizing an existing strategy or by `implementing a custom strategy from scratch `_. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list, the other clients in this round to not receive this "special" config value): +This can be achieved by customizing an existing strategy or by :doc:`implementing a custom strategy from scratch `. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list; the other clients in this round do not receive this "special" config value): .. code-block:: python diff --git a/doc/source/how-to-configure-logging.rst b/doc/source/how-to-configure-logging.rst index 3fcfe85b9592..d5559429a73c 100644 --- a/doc/source/how-to-configure-logging.rst +++ b/doc/source/how-to-configure-logging.rst @@ -129,4 +129,4 @@ Log to a remote service The :code:`fl.common.logger.configure` function, also allows specifying a host to which logs can be pushed (via :code:`POST`) through a native Python :code:`logging.handler.HTTPHandler`. This is a particularly useful feature in :code:`gRPC`-based Federated Learning workloads where otherwise gathering logs from all entities (i.e. the server and the clients) might be cumbersome. -Note that in Flower simulation, the server automatically displays all logs. You can still specify a :code:`HTTPHandler` should you whish to backup or analyze the logs somewhere else. +Note that in Flower simulation, the server automatically displays all logs. You can still specify a :code:`HTTPHandler` should you wish to back up or analyze the logs somewhere else.
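+A minimal sketch of pushing logs to a remote service, assuming the :code:`identifier` and :code:`host` parameters described above (the URL is a placeholder):
+
+.. code-block:: python
+
+    import flwr as fl
+
+    # Route all subsequent Flower logs to the given host via HTTP POST
+    fl.common.logger.configure(
+        identifier="myFlowerExperiment",
+        host="http://192.0.2.10:8080",
+    )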
diff --git a/doc/source/how-to-enable-ssl-connections.rst b/doc/source/how-to-enable-ssl-connections.rst index fa59d4423c5a..051dd5711497 100644 --- a/doc/source/how-to-enable-ssl-connections.rst +++ b/doc/source/how-to-enable-ssl-connections.rst @@ -75,9 +75,9 @@ We are now going to show how to write a client which uses the previously generat client = MyFlowerClient() # Start client - fl.client.start_numpy_client( + fl.client.start_client( "localhost:8080", - client=client, + client=client.to_client(), root_certificates=Path(".cache/certificates/ca.crt").read_bytes(), ) diff --git a/doc/source/how-to-implement-strategies.rst b/doc/source/how-to-implement-strategies.rst index 7997503b65a8..01bbb3042973 100644 --- a/doc/source/how-to-implement-strategies.rst +++ b/doc/source/how-to-implement-strategies.rst @@ -233,7 +233,7 @@ The return value is a list of tuples, each representing the instructions that wi * Use the :code:`client_manager` to randomly sample all (or a subset of) available clients (each represented as a :code:`ClientProxy` object) * Pair each :code:`ClientProxy` with the same :code:`FitIns` holding the current global model :code:`parameters` and :code:`config` dict -More sophisticated implementations can use :code:`configure_fit` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the the list returned from :code:`configure_fit`. +More sophisticated implementations can use :code:`configure_fit` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the list returned from :code:`configure_fit`. .. note:: @@ -280,7 +280,7 @@ The return value is a list of tuples, each representing the instructions that wi * Use the :code:`client_manager` to randomly sample all (or a subset of) available clients (each represented as a :code:`ClientProxy` object) * Pair each :code:`ClientProxy` with the same :code:`EvaluateIns` holding the current global model :code:`parameters` and :code:`config` dict -More sophisticated implementations can use :code:`configure_evaluate` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the the list returned from :code:`configure_evaluate`. +More sophisticated implementations can use :code:`configure_evaluate` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the list returned from :code:`configure_evaluate`. .. note:: diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index ff3dbb605846..aebe5f7316de 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -57,7 +57,7 @@ Advanced installation options Install via Docker ~~~~~~~~~~~~~~~~~~ -`How to run Flower using Docker `_ +:doc:`How to run Flower using Docker ` Install pre-release ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/how-to-monitor-simulation.rst b/doc/source/how-to-monitor-simulation.rst index 740004914eed..f6c26a701d94 100644 --- a/doc/source/how-to-monitor-simulation.rst +++ b/doc/source/how-to-monitor-simulation.rst @@ -161,7 +161,7 @@ However, you can overwrite the defaults. When starting a simulation, do the foll ram_memory = 16_000 * 1024 * 1024 # 16 GB fl.simulation.start_simulation( # ... 
- # all the args you were specyfing before + # all the args you were specifying before # ... ray_init_args = { "include_dashboard": True, # we need this one for tracking @@ -187,7 +187,7 @@ Let’s also specify the resource for a single client. fl.simulation.start_simulation( # ... - # all the args you were specyfing before + # all the args you were specifying before # ... ray_init_args = { "include_dashboard": True, # we need this one for tracking @@ -231,6 +231,6 @@ A: Either the simulation has already finished, or you still need to start Promet Resources --------- -Ray Dashboard: ``_ +Ray Dashboard: ``_ -Ray Metrics: ``_ +Ray Metrics: ``_ diff --git a/doc/source/how-to-run-flower-using-docker.rst b/doc/source/how-to-run-flower-using-docker.rst index 27ff61c280cb..ed034c820142 100644 --- a/doc/source/how-to-run-flower-using-docker.rst +++ b/doc/source/how-to-run-flower-using-docker.rst @@ -1,5 +1,5 @@ Run Flower using Docker -==================== +======================= The simplest way to get started with Flower is by using the pre-made Docker images, which you can find on `Docker Hub `_. @@ -31,12 +31,12 @@ If you're looking to try out Flower, you can use the following command: .. code-block:: bash - $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/server:1.6.0-py3.11-ubuntu22.04 \ + $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/server:1.7.0-py3.11-ubuntu22.04 \ --insecure -The command will pull the Docker image with the tag ``1.6.0-py3.11-ubuntu22.04`` from Docker Hub. +The command will pull the Docker image with the tag ``1.7.0-py3.11-ubuntu22.04`` from Docker Hub. The tag contains the information which Flower, Python and Ubuntu is used. In this case, it -uses Flower 1.6.0, Python 3.11 and Ubuntu 22.04. The ``--rm`` flag tells Docker to remove +uses Flower 1.7.0, Python 3.11 and Ubuntu 22.04. The ``--rm`` flag tells Docker to remove the container after it exits. .. note:: @@ -54,14 +54,14 @@ to the Flower server. Here, we are passing the flag ``--insecure``. The ``--insecure`` flag enables insecure communication (using HTTP, not HTTPS) and should only be used for testing purposes. We strongly recommend enabling - `SSL `_ + `SSL `_ when deploying to a production environment. You can use ``--help`` to view all available flags that the server supports: .. code-block:: bash - $ docker run --rm flwr/server:1.6.0-py3.11-ubuntu22.04 --help + $ docker run --rm flwr/server:1.7.0-py3.11-ubuntu22.04 --help Mounting a volume to store the state on the host system ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -75,7 +75,7 @@ flag ``--database`` to specify the name of the database file. .. code-block:: bash $ docker run --rm \ - -p 9091:9091 -p 9092:9092 -v ~/:/app/ flwr/server:1.6.0-py3.11-ubuntu22.04 \ + -p 9091:9091 -p 9092:9092 -v ~/:/app/ flwr/server:1.7.0-py3.11-ubuntu22.04 \ --insecure \ --database state.db @@ -90,7 +90,7 @@ To enable SSL, you will need a CA certificate, a server certificate and a server .. note:: For testing purposes, you can generate your own self-signed certificates. The - `Enable SSL connections `_ + `Enable SSL connections `_ page contains a section that will guide you through the process. Assuming all files we need are in the local ``certificates`` directory, we can use the flag @@ -101,7 +101,7 @@ the server with the ``--certificates`` flag. .. 
code-block:: bash $ docker run --rm \ - -p 9091:9091 -p 9092:9092 -v ./certificates/:/app/ flwr/server:1.6.0-py3.11-ubuntu22.04 \ + -p 9091:9091 -p 9092:9092 -v ./certificates/:/app/ flwr/server:1.7.0-py3.11-ubuntu22.04 \ --certificates ca.crt server.pem server.key Using a different Flower or Python version ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -118,19 +118,19 @@ updates of system dependencies that should not change the functionality of Flowe want to ensure that you always use the same image, you can specify the hash of the image instead of the tag. -The following command returns the current image hash referenced by the ``server:1.6.0-py3.11-ubuntu22.04`` tag: +The following command returns the current image hash referenced by the ``server:1.7.0-py3.11-ubuntu22.04`` tag: .. code-block:: bash - $ docker inspect --format='{{index .RepoDigests 0}}' flwr/server:1.6.0-py3.11-ubuntu22.04 - flwr/server@sha256:43fc389bcb016feab2b751b2ccafc9e9a906bb0885bd92b972329801086bc017 + $ docker inspect --format='{{index .RepoDigests 0}}' flwr/server:1.7.0-py3.11-ubuntu22.04 + flwr/server@sha256:c4be5012f9d73e3022e98735a889a463bb2f4f434448ebc19c61379920b1b327 Next, we can pin the hash when running a new server container: .. code-block:: bash $ docker run \ - --rm flwr/server@sha256:43fc389bcb016feab2b751b2ccafc9e9a906bb0885bd92b972329801086bc017 \ + --rm flwr/server@sha256:c4be5012f9d73e3022e98735a889a463bb2f4f434448ebc19c61379920b1b327 \ --insecure Setting environment variables @@ -141,4 +141,4 @@ To set a variable inside a Docker container, you can use the ``-e = .. code-block:: bash $ docker run -e FLWR_TELEMETRY_ENABLED=0 \ - --rm flwr/server:1.6.0-py3.11-ubuntu22.04 --insecure + --rm flwr/server:1.7.0-py3.11-ubuntu22.04 --insecure diff --git a/doc/source/how-to-run-simulations.rst b/doc/source/how-to-run-simulations.rst index 707e3d3ffe84..d1dcb511ed51 100644 --- a/doc/source/how-to-run-simulations.rst +++ b/doc/source/how-to-run-simulations.rst @@ -7,7 +7,7 @@ Run simulations Simulating Federated Learning workloads is useful for a multitude of use-cases: you might want to run your workload on a large cohort of clients but without having to source, configure and mange a large number of physical devices; you might want to run your FL workloads as fast as possible on the compute systems you have access to without having to go through a complex setup process; you might want to validate your algorithm on different scenarios at varying levels of data and system heterogeneity, client availability, privacy budgets, etc. These are among some of the use-cases where simulating FL workloads makes sense. Flower can accommodate these scenarios by means of its `VirtualClientEngine `_ or VCE. -The :code:`VirtualClientEngine` schedules, launches and manages `virtual` clients. These clients are identical to `non-virtual` clients (i.e. the ones you launch via the command `flwr.client.start_numpy_client `_) in the sense that they can be configure by creating a class inheriting, for example, from `flwr.client.NumPyClient `_ and therefore behave in an identical way. In addition to that, clients managed by the :code:`VirtualClientEngine` are: +The :code:`VirtualClientEngine` schedules, launches and manages `virtual` clients. These clients are identical to `non-virtual` clients (i.e. the ones you launch via the command `flwr.client.start_client `_) in the sense that they can be configured by creating a class inheriting, for example, from `flwr.client.NumPyClient `_ and therefore behave in an identical way.
In addition to that, clients managed by the :code:`VirtualClientEngine` are: * resource-aware: this means that each client gets assigned a portion of the compute and memory on your system. You as a user can control this at the beginning of the simulation and allows you to control the degree of parallelism of your Flower FL simulation. The fewer the resources per client, the more clients can run concurrently on the same hardware. * self-managed: this means that you as a user do not need to launch clients manually, instead this gets delegated to :code:`VirtualClientEngine`'s internals. @@ -29,7 +29,7 @@ Running Flower simulations still require you to define your client class, a stra def client_fn(cid: str): # Return a standard Flower client - return MyFlowerClient() + return MyFlowerClient().to_client() # Launch the simulation hist = fl.simulation.start_simulation( diff --git a/doc/source/how-to-upgrade-to-flower-1.0.rst b/doc/source/how-to-upgrade-to-flower-1.0.rst index fd380e95d69c..3a55a1a953f5 100644 --- a/doc/source/how-to-upgrade-to-flower-1.0.rst +++ b/doc/source/how-to-upgrade-to-flower-1.0.rst @@ -50,7 +50,7 @@ Strategies / ``start_server`` / ``start_simulation`` - Replace ``num_rounds=1`` in ``start_simulation`` with the new ``config=ServerConfig(...)`` (see previous item) - Remove ``force_final_distributed_eval`` parameter from calls to ``start_server``. Distributed evaluation on all clients can be enabled by configuring the strategy to sample all clients for evaluation after the last round of training. - Rename parameter/ndarray conversion functions: - + - ``parameters_to_weights`` --> ``parameters_to_ndarrays`` - ``weights_to_parameters`` --> ``ndarrays_to_parameters`` @@ -81,11 +81,11 @@ Optional improvements Along with the necessary changes above, there are a number of potential improvements that just became possible: -- Remove "placeholder" methods from subclasses of ``Client`` or ``NumPyClient``. If you, for example, use server-side evaluation, then empy placeholder implementations of ``evaluate`` are no longer necessary. +- Remove "placeholder" methods from subclasses of ``Client`` or ``NumPyClient``. If you, for example, use server-side evaluation, then empty placeholder implementations of ``evaluate`` are no longer necessary. - Configure the round timeout via ``start_simulation``: ``start_simulation(..., config=flwr.server.ServerConfig(num_rounds=3, round_timeout=600.0), ...)`` Further help ------------ -Most official `Flower code examples `_ are already updated to Flower 1.0, they can serve as a reference for using the Flower 1.0 API. If there are further questionsm, `join the Flower Slack `_ and use the channgel ``#questions``. +Most official `Flower code examples `_ are already updated to Flower 1.0; they can serve as a reference for using the Flower 1.0 API. If there are further questions, `join the Flower Slack `_ and use the channel ``#questions``. diff --git a/doc/source/how-to-use-built-in-middleware-layers.rst b/doc/source/how-to-use-built-in-middleware-layers.rst deleted file mode 100644 index 2e91623b26be..000000000000 --- a/doc/source/how-to-use-built-in-middleware-layers.rst +++ /dev/null @@ -1,87 +0,0 @@ -Use Built-in Middleware Layers -============================== - -**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** - -In this tutorial, we will learn how to utilize built-in middleware layers to augment the behavior of a ``FlowerCallable``.
Middleware allows us to perform operations before and after a task is processed in the ``FlowerCallable``. - -What is middleware? -------------------- - -Middleware is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect incoming tasks (``TaskIns``) in the ``Fwd`` and the resulting tasks (``TaskRes``) in the ``Bwd``. The signature for a middleware layer (``Layer``) is as follows: - -.. code-block:: python - - FlowerCallable = Callable[[Fwd], Bwd] - Layer = Callable[[Fwd, FlowerCallable], Bwd] - -A typical middleware function might look something like this: - -.. code-block:: python - - def example_middleware(fwd: Fwd, ffn: FlowerCallable) -> Bwd: - # Do something with Fwd before passing to the inner ``FlowerCallable``. - bwd = ffn(fwd) - # Do something with Bwd before returning. - return bwd - -Using middleware layers ------------------------ - -To use middleware layers in your ``FlowerCallable``, you can follow these steps: - -1. Import the required middleware -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -First, import the built-in middleware layers you intend to use: - -.. code-block:: python - - import flwr as fl - from flwr.client.middleware import example_middleware1, example_middleware2 - -2. Define your client function -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Define your client function (``client_fn``) that will be wrapped by the middleware: - -.. code-block:: python - - def client_fn(cid): - # Your client code goes here. - return # your client - -3. Create the ``FlowerCallable`` with middleware -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Create your ``FlowerCallable`` and pass the middleware layers as a list to the ``layers`` argument. The order in which you provide the middleware layers matters: - -.. code-block:: python - - flower = fl.app.Flower( - client_fn=client_fn, - layers=[ - example_middleware1, # Middleware layer 1 - example_middleware2, # Middleware layer 2 - ] - ) - -Order of execution ------------------- - -When the ``FlowerCallable`` runs, the middleware layers are executed in the order they are provided in the list: - -1. ``example_middleware1`` (outermost layer) -2. ``example_middleware2`` (next layer) -3. Message handler (core function that handles ``TaskIns`` and returns ``TaskRes``) -4. ``example_middleware2`` (on the way back) -5. ``example_middleware1`` (outermost layer on the way back) - -Each middleware has a chance to inspect and modify the ``TaskIns`` in the ``Fwd`` before passing it to the next layer, and likewise with the ``TaskRes`` in the ``Bwd`` before returning it up the stack. - -Conclusion ----------- - -By following this guide, you have learned how to effectively use middleware layers to enhance your ``FlowerCallable``'s functionality. Remember that the order of middleware is crucial and affects how the input and output are processed. - -Enjoy building more robust and flexible ``FlowerCallable``s with middleware layers! diff --git a/doc/source/how-to-use-built-in-mods.rst b/doc/source/how-to-use-built-in-mods.rst new file mode 100644 index 000000000000..341139175074 --- /dev/null +++ b/doc/source/how-to-use-built-in-mods.rst @@ -0,0 +1,89 @@ +Use Built-in Mods +================= + +**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** + +In this tutorial, we will learn how to utilize built-in mods to augment the behavior of a ``ClientApp``. 
Mods (sometimes also called Modifiers) allow us to perform operations before and after a task is processed in the ``ClientApp``. + +What are Mods? +-------------- + +A Mod is a callable that wraps around a ``ClientApp``. It can manipulate or inspect the incoming ``Message`` and the resulting outgoing ``Message``. The signature for a ``Mod`` is as follows: + +.. code-block:: python + + ClientApp = Callable[[Message, Context], Message] + Mod = Callable[[Message, Context, ClientApp], Message] + +A typical mod function might look something like this: + +.. code-block:: python + + def example_mod(msg: Message, ctx: Context, nxt: ClientApp) -> Message: + # Do something with incoming Message (or Context) + # before passing to the inner ``ClientApp`` + msg = nxt(msg, ctx) + # Do something with outgoing Message (or Context) + # before returning + return msg + +Using Mods +---------- + +To use mods in your ``ClientApp``, you can follow these steps: + +1. Import the required mods +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, import the built-in mod you intend to use: + +.. code-block:: python + + import flwr as fl + from flwr.client.mod import example_mod_1, example_mod_2 + +2. Define your client function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Define your client function (``client_fn``) that will be wrapped by the mod(s): + +.. code-block:: python + + def client_fn(cid): + # Your client code goes here. + return # your client + +3. Create the ``ClientApp`` with mods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Create your ``ClientApp`` and pass the mods as a list to the ``mods`` argument. The order in which you provide the mods matters: + +.. code-block:: python + + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + example_mod_1, # Mod 1 + example_mod_2, # Mod 2 + ] + ) + +Order of execution +------------------ + +When the ``ClientApp`` runs, the mods are executed in the order they are provided in the list: + +1. ``example_mod_1`` (outermost mod) +2. ``example_mod_2`` (next mod) +3. Message handler (core function that handles the incoming ``Message`` and returns the outgoing ``Message``) +4. ``example_mod_2`` (on the way back) +5. ``example_mod_1`` (outermost mod on the way back) + +Each mod has a chance to inspect and modify the incoming ``Message`` before passing it to the next mod, and likewise with the outgoing ``Message`` before returning it up the stack. + +Conclusion +---------- + +By following this guide, you have learned how to effectively use mods to enhance your ``ClientApp``'s functionality. Remember that the order of mods is crucial and affects how the input and output are processed. + +Enjoy building a more robust and flexible ``ClientApp`` with mods! diff --git a/doc/source/how-to-use-differential-privacy.rst b/doc/source/how-to-use-differential-privacy.rst new file mode 100644 index 000000000000..c8901bd906cc --- /dev/null +++ b/doc/source/how-to-use-differential-privacy.rst @@ -0,0 +1,126 @@ +Use Differential Privacy +------------------------ +This guide explains how you can utilize differential privacy in the Flower framework. If you are not yet familiar with differential privacy, you can refer to :doc:`explanation-differential-privacy`. + +.. warning:: + + Differential Privacy in Flower is in a preview phase. If you plan to use these features in a production environment with sensitive data, feel free to contact us to discuss your requirements and to receive guidance on how to best use these features.
+ + +Central Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This approach consists of two separate phases: clipping of the updates and adding noise to the aggregated model. +For the clipping phase, the Flower framework lets you decide whether clipping is performed on the server side or the client side. + +- **Server-side Clipping**: This approach has the advantage of the server enforcing uniform clipping across all clients' updates and reducing the communication overhead for clipping values. However, it also has the disadvantage of increasing the computational load on the server due to the need to perform the clipping operation for all clients. +- **Client-side Clipping**: This approach has the advantage of reducing the computational overhead on the server. However, it also has the disadvantage of lacking centralized control, as the server has less control over the clipping process. + + + +Server-side Clipping +^^^^^^^^^^^^^^^^^^^^ +For central DP with server-side clipping, there are two :code:`Strategy` classes that act as wrappers around the actual :code:`Strategy` instance (for example, :code:`FedAvg`). +The two wrapper classes are :code:`DifferentialPrivacyServerSideFixedClipping` and :code:`DifferentialPrivacyServerSideAdaptiveClipping` for fixed and adaptive clipping. + +.. image:: ./_static/DP/serversideCDP.png + :align: center + :width: 700 + :alt: server side clipping + + +The code sample below enables the :code:`FedAvg` strategy to use server-side fixed clipping using the :code:`DifferentialPrivacyServerSideFixedClipping` wrapper class. +The same approach can be used with :code:`DifferentialPrivacyServerSideAdaptiveClipping` by adjusting the corresponding input parameters. + +.. code-block:: python + + from flwr.server.strategy import DifferentialPrivacyServerSideFixedClipping + + # Create the strategy + strategy = fl.server.strategy.FedAvg(...) + + # Wrap the strategy with the DifferentialPrivacyServerSideFixedClipping wrapper + dp_strategy = DifferentialPrivacyServerSideFixedClipping( + strategy, + cfg.noise_multiplier, + cfg.clipping_norm, + cfg.num_sampled_clients, + ) + + + +Client-side Clipping +^^^^^^^^^^^^^^^^^^^^ +For central DP with client-side clipping, the server sends the clipping value to selected clients on each round. +Clients can use existing Flower :code:`Mods` to perform the clipping. +Two mods are available for fixed and adaptive client-side clipping: :code:`fixedclipping_mod` and :code:`adaptiveclipping_mod` with corresponding server-side wrappers :code:`DifferentialPrivacyClientSideFixedClipping` and :code:`DifferentialPrivacyClientSideAdaptiveClipping`. + +.. image:: ./_static/DP/clientsideCDP.png + :align: center + :width: 800 + :alt: client side clipping + + +The code sample below enables the :code:`FedAvg` strategy to use differential privacy with client-side fixed clipping using both the :code:`DifferentialPrivacyClientSideFixedClipping` wrapper class and, on the client, :code:`fixedclipping_mod`: + +.. code-block:: python + + from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping + + # Create the strategy + strategy = fl.server.strategy.FedAvg(...)
+ + # Wrap the strategy with the DifferentialPrivacyClientSideFixedClipping wrapper + dp_strategy = DifferentialPrivacyClientSideFixedClipping( + strategy, + cfg.noise_multiplier, + cfg.clipping_norm, + cfg.num_sampled_clients, + ) + +In addition to the server-side strategy wrapper, the :code:`ClientApp` needs to configure the matching :code:`fixedclipping_mod` to perform the client-side clipping: + +.. code-block:: python + + from flwr.client.mod import fixedclipping_mod + + # Add fixedclipping_mod to the client-side mods + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + fixedclipping_mod, + ] + ) + + +Local Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~ +To utilize local differential privacy (DP) and add noise to the client model parameters before transmitting them to the server in Flower, you can use :code:`LocalDpMod`. The following hyperparameters need to be set: clipping norm value, sensitivity, epsilon, and delta. + +.. image:: ./_static/DP/localdp.png + :align: center + :width: 700 + :alt: local DP mod + +Below is a code example that shows how to use :code:`LocalDpMod`: + +.. code-block:: python + + from flwr.client.mod.localdp_mod import LocalDpMod + + # Create an instance of the mod with the required params + local_dp_obj = LocalDpMod( + cfg.clipping_norm, cfg.sensitivity, cfg.epsilon, cfg.delta + ) + # Add local_dp_obj to the client-side mods + + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[local_dp_obj], + ) + + +Please note that the order of mods, especially those that modify parameters, is important when using multiple modifiers. Typically, differential privacy (DP) modifiers should be the last to operate on parameters. + +Local Training using Privacy Engines +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +To ensure data instance-level privacy during local model training on the client side, consider leveraging privacy engines such as Opacus and TensorFlow Privacy. For examples of using Flower with these engines, please refer to the Flower examples directory (`Opacus `_, `TensorFlow Privacy `_). \ No newline at end of file diff --git a/doc/source/how-to-use-strategies.rst b/doc/source/how-to-use-strategies.rst index 6d24f97bd7f6..d0e2cd63a091 100644 --- a/doc/source/how-to-use-strategies.rst +++ b/doc/source/how-to-use-strategies.rst @@ -45,7 +45,7 @@ Configuring client fit and client evaluate ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The server can pass new configuration values to the client each round by providing a function to :code:`on_fit_config_fn`. The provided function will be called by the strategy and must return a dictionary of configuration key values pairs that will be sent to the client. -It must return a dictionary of arbitraty configuration values :code:`client.fit` and :code:`client.evaluate` functions during each round of federated learning. +This dictionary of arbitrary configuration values is received by the :code:`client.fit` and :code:`client.evaluate` functions during each round of federated learning. .. code-block:: python diff --git a/doc/source/index.rst b/doc/source/index.rst index 5df591d6ce05..894155be03f1 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -4,7 +4,7 @@ Flower Framework Documentation .. meta:: :description: Check out the documentation of the main Flower Framework enabling easy Python development for Federated Learning. -Welcome to Flower's documentation. `Flower `_ is a friendly federated learning framework. +Welcome to Flower's documentation. `Flower `_ is a friendly federated learning framework.
Join the Flower Community @@ -12,7 +12,7 @@ Join the Flower Community The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts. -.. button-link:: https://flower.dev/join-slack +.. button-link:: https://flower.ai/join-slack :color: primary :shadow: @@ -91,8 +91,9 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal. how-to-configure-logging how-to-enable-ssl-connections how-to-upgrade-to-flower-1.0 - how-to-use-built-in-middleware-layers + how-to-use-built-in-mods how-to-run-flower-using-docker + how-to-use-differential-privacy .. toctree:: :maxdepth: 1 diff --git a/doc/source/ref-api-cli.rst b/doc/source/ref-api-cli.rst index 039a2ea27cf8..63579143755d 100644 --- a/doc/source/ref-api-cli.rst +++ b/doc/source/ref-api-cli.rst @@ -1,42 +1,52 @@ Flower CLI reference ==================== -.. _flower-server-apiref: +.. _flower-superlink-apiref: -flower-server -~~~~~~~~~~~~~ +flower-superlink +~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_server - :prog: flower-server + :func: _parse_args_run_superlink + :prog: flower-superlink -.. _flower-driver-apiref: +.. _flower-driver-api-apiref: flower-driver-api ~~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_driver + :func: _parse_args_run_driver_api :prog: flower-driver-api -.. _flower-fleet-apiref: +.. _flower-fleet-api-apiref: flower-fleet-api ~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_fleet + :func: _parse_args_run_fleet_api :prog: flower-fleet-api -.. .. _flower-client-apiref: +.. _flower-client-app-apiref: + +flower-client-app +~~~~~~~~~~~~~~~~~ + +.. argparse:: + :module: flwr.client.app + :func: _parse_args_run_client_app + :prog: flower-client-app + +.. _flower-server-app-apiref: -.. flower-client -.. ~~~~~~~~~~~~~ +flower-server-app +~~~~~~~~~~~~~~~~~ - .. argparse:: -.. :filename: flwr.client -.. :func: run_client -.. :prog: flower-client +.. argparse:: + :module: flwr.server.run_serverapp + :func: _parse_args_run_server_app + :prog: flower-server-app diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 5f323bc80baa..1a6524d29353 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,23 +2,102 @@ ## Unreleased -- **Add scikit-learn tabular data example** ([#2719](https://github.com/adap/flower/pull/2719)) +### What's new? + +### Incompatible changes + +## v1.7.0 (2024-02-05) + +### Thanks to our contributors + +We would like to give our special thanks to all the contributors who made the new version of Flower possible (in `git shortlog` order): + +`Aasheesh Singh`, `Adam Narozniak`, `Aml Hassan Esmil`, `Charles Beauville`, `Daniel J. Beutel`, `Daniel Nata Nugraha`, `Edoardo Gabrielli`, `Gustavo Bertoli`, `HelinLin`, `Heng Pan`, `Javier`, `M S Chaitanya Kumar`, `Mohammad Naseri`, `Nikos Vlachakis`, `Pritam Neog`, `Robert Kuska`, `Robert Steiner`, `Taner Topal`, `Yahia Salaheldin Shaaban`, `Yan Gao`, `Yasar Abbas` + +### What's new? + +- **Introduce stateful clients (experimental)** ([#2770](https://github.com/adap/flower/pull/2770), [#2686](https://github.com/adap/flower/pull/2686), [#2696](https://github.com/adap/flower/pull/2696), [#2643](https://github.com/adap/flower/pull/2643), [#2769](https://github.com/adap/flower/pull/2769)) + + Subclasses of `Client` and `NumPyClient` can now store local state that remains on the client. 
Let's start with the highlight first: this new feature is compatible with both simulated clients (via `start_simulation`) and networked clients (via `start_client`). It's also the first preview of new abstractions like `Context` and `RecordSet`. Clients can access state of type `RecordSet` via `state: RecordSet = self.context.state`. Changes to this `RecordSet` are preserved across different rounds of execution to enable stateful computations in a unified way across simulation and deployment. + +- **Improve performance** ([#2293](https://github.com/adap/flower/pull/2293)) + + Flower is faster than ever. All `FedAvg`-derived strategies now use in-place aggregation to reduce memory consumption. The Flower client serialization/deserialization has been rewritten from the ground up, which results in significant speedups, especially when the client-side training time is short. + +- **Support Federated Learning with Apple MLX and Flower** ([#2693](https://github.com/adap/flower/pull/2693)) + + Flower has official support for federated learning using [Apple MLX](https://ml-explore.github.io/mlx) via the new `quickstart-mlx` code example. + +- **Introduce new XGBoost cyclic strategy** ([#2666](https://github.com/adap/flower/pull/2666), [#2668](https://github.com/adap/flower/pull/2668)) + + A new strategy called `FedXgbCyclic` supports a client-by-client style of training (often called cyclic). The `xgboost-comprehensive` code example shows how to use it in a full project. In addition to that, `xgboost-comprehensive` now also supports simulation mode. With this, Flower offers best-in-class XGBoost support. + +- **Support Python 3.11** ([#2394](https://github.com/adap/flower/pull/2394)) + + Framework tests now run on Python 3.8, 3.9, 3.10, and 3.11. This will ensure better support for users using more recent Python versions. + +- **Update gRPC and ProtoBuf dependencies** ([#2814](https://github.com/adap/flower/pull/2814)) -- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) + The `grpcio` and `protobuf` dependencies were updated to their latest versions for improved security and performance. -- **Retiring MXNet examples** The development of the MXNet fremework has ended and the project is now [archived on GitHub](https://github.com/apache/mxnet). Existing MXNet examples won't receive updates [#2724](https://github.com/adap/flower/pull/2724) +- **Introduce Docker image for Flower server** ([#2700](https://github.com/adap/flower/pull/2700), [#2688](https://github.com/adap/flower/pull/2688), [#2705](https://github.com/adap/flower/pull/2705), [#2695](https://github.com/adap/flower/pull/2695), [#2747](https://github.com/adap/flower/pull/2747), [#2746](https://github.com/adap/flower/pull/2746), [#2680](https://github.com/adap/flower/pull/2680), [#2682](https://github.com/adap/flower/pull/2682), [#2701](https://github.com/adap/flower/pull/2701)) + + The Flower server can now be run using an official Docker image. A new how-to guide explains [how to run Flower using Docker](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html). An official Flower client Docker image will follow. 
+ +- **Introduce** `flower-via-docker-compose` **example** ([#2626](https://github.com/adap/flower/pull/2626)) + +- **Introduce** `quickstart-sklearn-tabular` **example** ([#2719](https://github.com/adap/flower/pull/2719)) + +- **Introduce** `custom-metrics` **example** ([#1958](https://github.com/adap/flower/pull/1958)) + +- **Update code examples to use Flower Datasets** ([#2450](https://github.com/adap/flower/pull/2450), [#2456](https://github.com/adap/flower/pull/2456), [#2318](https://github.com/adap/flower/pull/2318), [#2712](https://github.com/adap/flower/pull/2712)) + + Several code examples were updated to use [Flower Datasets](https://flower.ai/docs/datasets/). + +- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381), [#2805](https://github.com/adap/flower/pull/2805), [#2782](https://github.com/adap/flower/pull/2782), [#2806](https://github.com/adap/flower/pull/2806), [#2829](https://github.com/adap/flower/pull/2829), [#2825](https://github.com/adap/flower/pull/2825), [#2816](https://github.com/adap/flower/pull/2816), [#2726](https://github.com/adap/flower/pull/2726), [#2659](https://github.com/adap/flower/pull/2659), [#2655](https://github.com/adap/flower/pull/2655)) + + Many Flower code examples received substantial updates. - **Update Flower Baselines** - - HFedXGBoost [#2226](https://github.com/adap/flower/pull/2226) + - HFedXGBoost ([#2226](https://github.com/adap/flower/pull/2226), [#2771](https://github.com/adap/flower/pull/2771)) + - FedVSSL ([#2412](https://github.com/adap/flower/pull/2412)) + - FedNova ([#2179](https://github.com/adap/flower/pull/2179)) + - HeteroFL ([#2439](https://github.com/adap/flower/pull/2439)) + - FedAvgM ([#2246](https://github.com/adap/flower/pull/2246)) + - FedPara ([#2722](https://github.com/adap/flower/pull/2722)) + +- **Improve documentation** ([#2674](https://github.com/adap/flower/pull/2674), [#2480](https://github.com/adap/flower/pull/2480), [#2826](https://github.com/adap/flower/pull/2826), [#2727](https://github.com/adap/flower/pull/2727), [#2761](https://github.com/adap/flower/pull/2761), [#2900](https://github.com/adap/flower/pull/2900)) + +- **Improved testing and development infrastructure** ([#2797](https://github.com/adap/flower/pull/2797), [#2676](https://github.com/adap/flower/pull/2676), [#2644](https://github.com/adap/flower/pull/2644), [#2656](https://github.com/adap/flower/pull/2656), [#2848](https://github.com/adap/flower/pull/2848), [#2675](https://github.com/adap/flower/pull/2675), [#2735](https://github.com/adap/flower/pull/2735), [#2767](https://github.com/adap/flower/pull/2767), [#2732](https://github.com/adap/flower/pull/2732), [#2744](https://github.com/adap/flower/pull/2744), [#2681](https://github.com/adap/flower/pull/2681), [#2699](https://github.com/adap/flower/pull/2699), [#2745](https://github.com/adap/flower/pull/2745), [#2734](https://github.com/adap/flower/pull/2734), [#2731](https://github.com/adap/flower/pull/2731), [#2652](https://github.com/adap/flower/pull/2652), [#2720](https://github.com/adap/flower/pull/2720), [#2721](https://github.com/adap/flower/pull/2721), [#2717](https://github.com/adap/flower/pull/2717), [#2864](https://github.com/adap/flower/pull/2864), [#2694](https://github.com/adap/flower/pull/2694), [#2709](https://github.com/adap/flower/pull/2709), [#2658](https://github.com/adap/flower/pull/2658), [#2796](https://github.com/adap/flower/pull/2796), [#2692](https://github.com/adap/flower/pull/2692), [#2657](https://github.com/adap/flower/pull/2657), 
[#2813](https://github.com/adap/flower/pull/2813), [#2661](https://github.com/adap/flower/pull/2661), [#2398](https://github.com/adap/flower/pull/2398)) + + The Flower testing and development infrastructure has received substantial updates. This makes Flower 1.7 the most tested release ever. + +- **Update dependencies** ([#2753](https://github.com/adap/flower/pull/2753), [#2651](https://github.com/adap/flower/pull/2651), [#2739](https://github.com/adap/flower/pull/2739), [#2837](https://github.com/adap/flower/pull/2837), [#2788](https://github.com/adap/flower/pull/2788), [#2811](https://github.com/adap/flower/pull/2811), [#2774](https://github.com/adap/flower/pull/2774), [#2790](https://github.com/adap/flower/pull/2790), [#2751](https://github.com/adap/flower/pull/2751), [#2850](https://github.com/adap/flower/pull/2850), [#2812](https://github.com/adap/flower/pull/2812), [#2872](https://github.com/adap/flower/pull/2872), [#2736](https://github.com/adap/flower/pull/2736), [#2756](https://github.com/adap/flower/pull/2756), [#2857](https://github.com/adap/flower/pull/2857), [#2757](https://github.com/adap/flower/pull/2757), [#2810](https://github.com/adap/flower/pull/2810), [#2740](https://github.com/adap/flower/pull/2740), [#2789](https://github.com/adap/flower/pull/2789)) + +- **General improvements** ([#2803](https://github.com/adap/flower/pull/2803), [#2847](https://github.com/adap/flower/pull/2847), [#2877](https://github.com/adap/flower/pull/2877), [#2690](https://github.com/adap/flower/pull/2690), [#2889](https://github.com/adap/flower/pull/2889), [#2874](https://github.com/adap/flower/pull/2874), [#2819](https://github.com/adap/flower/pull/2819), [#2689](https://github.com/adap/flower/pull/2689), [#2457](https://github.com/adap/flower/pull/2457), [#2870](https://github.com/adap/flower/pull/2870), [#2669](https://github.com/adap/flower/pull/2669), [#2876](https://github.com/adap/flower/pull/2876), [#2885](https://github.com/adap/flower/pull/2885), [#2858](https://github.com/adap/flower/pull/2858), [#2867](https://github.com/adap/flower/pull/2867), [#2351](https://github.com/adap/flower/pull/2351), [#2886](https://github.com/adap/flower/pull/2886), [#2860](https://github.com/adap/flower/pull/2860), [#2828](https://github.com/adap/flower/pull/2828), [#2869](https://github.com/adap/flower/pull/2869), [#2875](https://github.com/adap/flower/pull/2875), [#2733](https://github.com/adap/flower/pull/2733), [#2488](https://github.com/adap/flower/pull/2488), [#2646](https://github.com/adap/flower/pull/2646), [#2879](https://github.com/adap/flower/pull/2879), [#2821](https://github.com/adap/flower/pull/2821), [#2855](https://github.com/adap/flower/pull/2855), [#2800](https://github.com/adap/flower/pull/2800), [#2807](https://github.com/adap/flower/pull/2807), [#2801](https://github.com/adap/flower/pull/2801), [#2804](https://github.com/adap/flower/pull/2804), [#2851](https://github.com/adap/flower/pull/2851), [#2787](https://github.com/adap/flower/pull/2787), [#2852](https://github.com/adap/flower/pull/2852), [#2672](https://github.com/adap/flower/pull/2672), [#2759](https://github.com/adap/flower/pull/2759)) + +### Incompatible changes + +- **Deprecate** `start_numpy_client` ([#2563](https://github.com/adap/flower/pull/2563), [#2718](https://github.com/adap/flower/pull/2718)) + + Until now, clients of type `NumPyClient` needed to be started via `start_numpy_client`. In our efforts to consolidate framework APIs, we have introduced changes, and now all client types should start via `start_client`. 
To continue using `NumPyClient` clients, you simply need to first call the `.to_client()` method and then pass the returned `Client` object to `start_client`. The examples and the documentation have been updated accordingly (see the minimal migration sketch at the end of this section). + +- **Deprecate legacy DP wrappers** ([#2749](https://github.com/adap/flower/pull/2749)) + + Legacy DP wrapper classes are deprecated, but still functional. This is in preparation for an all-new pluggable version of differential privacy support in Flower. + +- **Make optional arg** `--callable` **in** `flower-client` **a required positional arg** ([#2673](https://github.com/adap/flower/pull/2673)) + +- **Rename** `certificates` **to** `root_certificates` **in** `Driver` ([#2890](https://github.com/adap/flower/pull/2890)) - - FedVSSL [#2412](https://github.com/adap/flower/pull/2412) +- **Drop experimental** `Task` **fields** ([#2866](https://github.com/adap/flower/pull/2866), [#2865](https://github.com/adap/flower/pull/2865)) - - FedNova [#2179](https://github.com/adap/flower/pull/2179) + Experimental fields `sa`, `legacy_server_message` and `legacy_client_message` were removed from the `Task` message. The removed fields are superseded by the new `RecordSet` abstraction. - - HeteroFL [#2439](https://github.com/adap/flower/pull/2439) +- **Retire MXNet examples** ([#2724](https://github.com/adap/flower/pull/2724)) - - FedAvgM [#2246](https://github.com/adap/flower/pull/2246) + The development of the MXNet framework has ended and the project is now [archived on GitHub](https://github.com/apache/mxnet). Existing MXNet examples won't receive updates.
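+A minimal migration sketch for the `start_numpy_client` deprecation above (the client class is illustrative):
+
+```python
+import flwr as fl
+
+class FlowerClient(fl.client.NumPyClient):
+    ...  # your existing NumPyClient implementation
+
+# Before (deprecated):
+# fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
+
+# After: convert with `.to_client()` and use `start_client`
+fl.client.start_client(
+    server_address="127.0.0.1:8080",
+    client=FlowerClient().to_client(),
+)
+```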
[#2597](https://github.com/adap/flower/pull/2597), [#2623](https://github.com/adap/flower/pull/2623)) -- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317), [#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2561](https://github.com/adap/flower/pull/2561), [#2273](https://github.com/adap/flower/pull/2273), [#2267](https://github.com/adap/flower/pull/2267), [#2274](https://github.com/adap/flower/pull/2274), [#2275](https://github.com/adap/flower/pull/2275), [#2432](https://github.com/adap/flower/pull/2432), [#2251](https://github.com/adap/flower/pull/2251), [#2321](https://github.com/adap/flower/pull/2321), [#1936](https://github.com/adap/flower/pull/1936), [#2408](https://github.com/adap/flower/pull/2408), [#2413](https://github.com/adap/flower/pull/2413), [#2401](https://github.com/adap/flower/pull/2401), [#2531](https://github.com/adap/flower/pull/2531), [#2534](https://github.com/adap/flower/pull/2534), [#2535](https://github.com/adap/flower/pull/2535), [#2521](https://github.com/adap/flower/pull/2521), [#2553](https://github.com/adap/flower/pull/2553), [#2596](https://github.com/adap/flower/pull/2596)) +- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [#2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [#2317](https://github.com/adap/flower/pull/2317), [#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2561](https://github.com/adap/flower/pull/2561), [#2273](https://github.com/adap/flower/pull/2273), [#2267](https://github.com/adap/flower/pull/2267), [#2274](https://github.com/adap/flower/pull/2274), [#2275](https://github.com/adap/flower/pull/2275), [#2432](https://github.com/adap/flower/pull/2432), [#2251](https://github.com/adap/flower/pull/2251), [#2321](https://github.com/adap/flower/pull/2321), [#1936](https://github.com/adap/flower/pull/1936), [#2408](https://github.com/adap/flower/pull/2408), [#2413](https://github.com/adap/flower/pull/2413), [#2401](https://github.com/adap/flower/pull/2401), [#2531](https://github.com/adap/flower/pull/2531), [#2534](https://github.com/adap/flower/pull/2534), [#2535](https://github.com/adap/flower/pull/2535), [#2521](https://github.com/adap/flower/pull/2521), [#2553](https://github.com/adap/flower/pull/2553), [#2596](https://github.com/adap/flower/pull/2596)) Flower received many improvements under the hood, too many to list here. @@ -134,11 +213,11 @@ We would like to give our special thanks to all the contributors who made the ne The new simulation engine has been rewritten from the ground up, yet it remains fully backwards compatible. It offers much improved stability and memory handling, especially when working with GPUs. Simulations transparently adapt to different settings to scale simulation in CPU-only, CPU+GPU, multi-GPU, or multi-node multi-GPU environments. 
- Comprehensive documentation includes a new [how-to run simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.dev/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.dev/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). + Comprehensive documentation includes a new [how-to run simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.ai/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.ai/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). - **Restructure Flower Docs** ([#1824](https://github.com/adap/flower/pull/1824), [#1865](https://github.com/adap/flower/pull/1865), [#1884](https://github.com/adap/flower/pull/1884), [#1887](https://github.com/adap/flower/pull/1887), [#1919](https://github.com/adap/flower/pull/1919), [#1922](https://github.com/adap/flower/pull/1922), [#1920](https://github.com/adap/flower/pull/1920), [#1923](https://github.com/adap/flower/pull/1923), [#1924](https://github.com/adap/flower/pull/1924), [#1962](https://github.com/adap/flower/pull/1962), [#2006](https://github.com/adap/flower/pull/2006), [#2133](https://github.com/adap/flower/pull/2133), [#2203](https://github.com/adap/flower/pull/2203), [#2215](https://github.com/adap/flower/pull/2215), [#2122](https://github.com/adap/flower/pull/2122), [#2223](https://github.com/adap/flower/pull/2223), [#2219](https://github.com/adap/flower/pull/2219), [#2232](https://github.com/adap/flower/pull/2232), [#2233](https://github.com/adap/flower/pull/2233), [#2234](https://github.com/adap/flower/pull/2234), [#2235](https://github.com/adap/flower/pull/2235), [#2237](https://github.com/adap/flower/pull/2237), [#2238](https://github.com/adap/flower/pull/2238), [#2242](https://github.com/adap/flower/pull/2242), [#2231](https://github.com/adap/flower/pull/2231), [#2243](https://github.com/adap/flower/pull/2243), [#2227](https://github.com/adap/flower/pull/2227)) - Much effort went into a completely restructured Flower docs experience. The documentation on [flower.dev/docs](flower.dev/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. + Much effort went into a completely restructured Flower docs experience. The documentation on [flower.ai/docs](https://flower.ai/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. - **Introduce Flower Swift SDK** ([#1858](https://github.com/adap/flower/pull/1858), [#1897](https://github.com/adap/flower/pull/1897)) @@ -216,7 +295,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce support for XGBoost (**`FedXgbNnAvg` **strategy and example)** ([#1694](https://github.com/adap/flower/pull/1694), [#1709](https://github.com/adap/flower/pull/1709), [#1715](https://github.com/adap/flower/pull/1715), [#1717](https://github.com/adap/flower/pull/1717), [#1763](https://github.com/adap/flower/pull/1763), [#1795](https://github.com/adap/flower/pull/1795)) - XGBoost is a tree-based ensemble machine learning algorithm that uses gradient boosting to improve model accuracy. 
We added a new `FedXgbNnAvg` [strategy](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy/fedxgb_nn_avg.py), and a [code example](https://github.com/adap/flower/tree/main/examples/quickstart_xgboost_horizontal) that demonstrates the usage of this new strategy in an XGBoost project. + XGBoost is a tree-based ensemble machine learning algorithm that uses gradient boosting to improve model accuracy. We added a new `FedXgbNnAvg` [strategy](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy/fedxgb_nn_avg.py), and a [code example](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) that demonstrates the usage of this new strategy in an XGBoost project. - **Introduce iOS SDK (preview)** ([#1621](https://github.com/adap/flower/pull/1621), [#1764](https://github.com/adap/flower/pull/1764)) @@ -224,11 +303,11 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new "What is Federated Learning?" tutorial** ([#1657](https://github.com/adap/flower/pull/1657), [#1721](https://github.com/adap/flower/pull/1721)) - A new [entry-level tutorial](https://flower.dev/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Fedetated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! + A new [entry-level tutorial](https://flower.ai/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Federated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! - **Introduce new Flower Baseline: FedProx MNIST** ([#1513](https://github.com/adap/flower/pull/1513), [#1680](https://github.com/adap/flower/pull/1680), [#1681](https://github.com/adap/flower/pull/1681), [#1679](https://github.com/adap/flower/pull/1679)) - This new baseline replicates the MNIST+CNN task from the paper [Federated Optimization in Heterogeneous Networks (Li et al., 2018)](https://arxiv.org/abs/1812.06127). It uses the `FedProx` strategy, which aims at making convergence more robust in heterogenous settings. + This new baseline replicates the MNIST+CNN task from the paper [Federated Optimization in Heterogeneous Networks (Li et al., 2018)](https://arxiv.org/abs/1812.06127). It uses the `FedProx` strategy, which aims at making convergence more robust in heterogeneous settings. - **Introduce new Flower Baseline: FedAvg FEMNIST** ([#1655](https://github.com/adap/flower/pull/1655)) @@ -250,7 +329,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new example using** `TabNet` **and Flower** ([#1725](https://github.com/adap/flower/pull/1725)) - TabNet is a powerful and flexible framework for training machine learning models on tabular data. We now have a federated example using Flower: [https://github.com/adap/flower/tree/main/examples/tabnet](https://github.com/adap/flower/tree/main/examples/quickstart_tabnet). + TabNet is a powerful and flexible framework for training machine learning models on tabular data. We now have a federated example using Flower: [quickstart-tabnet](https://github.com/adap/flower/tree/main/examples/quickstart-tabnet).
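The `FedProx` strategy used by the FedProx MNIST baseline above ships with Flower itself. A minimal server-side sketch (the parameter values are illustrative; `proximal_mu` weights the proximal term that clients are asked to add to their local objective, and the strategy forwards it to clients through the `config` dictionary they receive in `fit`):

```python
import flwr as fl

# FedProx aggregates like FedAvg, but asks clients to penalize local weights
# for drifting away from the global model, which helps in heterogeneous settings.
strategy = fl.server.strategy.FedProx(
    fraction_fit=0.5,          # sample 50% of available clients each round
    min_available_clients=10,  # wait until at least 10 clients are connected
    proximal_mu=0.1,           # strength of the proximal term
)
```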
- **Add new how-to guide for monitoring simulations** ([#1649](https://github.com/adap/flower/pull/1649)) @@ -292,7 +371,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new example of Federated Learning using fastai and Flower** ([#1598](https://github.com/adap/flower/pull/1598)) - A new code example (`quickstart_fastai`) demonstrates federated learning with [fastai](https://www.fast.ai/) and Flower. You can find it here: [quickstart_fastai](https://github.com/adap/flower/tree/main/examples/quickstart_fastai). + A new code example (`quickstart-fastai`) demonstrates federated learning with [fastai](https://www.fast.ai/) and Flower. You can find it here: [quickstart-fastai](https://github.com/adap/flower/tree/main/examples/quickstart-fastai). - **Make Android example compatible with** `flwr >= 1.0.0` **and the latest versions of Android** ([#1603](https://github.com/adap/flower/pull/1603)) @@ -338,7 +417,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new Flower Baseline: FedAvg MNIST** ([#1497](https://github.com/adap/flower/pull/1497), [#1552](https://github.com/adap/flower/pull/1552)) - Over the coming weeks, we will be releasing a number of new reference implementations useful especially to FL newcomers. They will typically revisit well known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.dev/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) + Over the coming weeks, we will be releasing a number of new reference implementations especially useful to FL newcomers. They will typically revisit well-known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.ai/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) - **Improve GPU support in simulations** ([#1555](https://github.com/adap/flower/pull/1555)) @@ -348,16 +427,16 @@ We would like to give our special thanks to all the contributors who made the ne Some users reported that Jupyter Notebooks have not always been easy to use on GPU instances. We listened and made improvements to all of our Jupyter notebooks!
Check out the updated notebooks here: - - [An Introduction to Federated Learning](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) - - [Strategies in Federated Learning](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) - - [Building a Strategy](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) - - [Client and NumPyClient](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) + - [An Introduction to Federated Learning](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) + - [Strategies in Federated Learning](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) + - [Building a Strategy](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) + - [Client and NumPyClient](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) - **Introduce optional telemetry** ([#1533](https://github.com/adap/flower/pull/1533), [#1544](https://github.com/adap/flower/pull/1544), [#1584](https://github.com/adap/flower/pull/1584)) After a [request for feedback](https://github.com/adap/flower/issues/1534) from the community, the Flower open-source project introduces optional collection of *anonymous* usage metrics to make well-informed decisions to improve Flower. Doing this enables the Flower team to understand how Flower is used and what challenges users might face. - **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. [Read more.](https://flower.dev/docs/telemetry.html). + **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. [Read more](https://flower.ai/docs/telemetry.html). - **Introduce (experimental) Driver API** ([#1520](https://github.com/adap/flower/pull/1520), [#1525](https://github.com/adap/flower/pull/1525), [#1545](https://github.com/adap/flower/pull/1545), [#1546](https://github.com/adap/flower/pull/1546), [#1550](https://github.com/adap/flower/pull/1550), [#1551](https://github.com/adap/flower/pull/1551), [#1567](https://github.com/adap/flower/pull/1567)) @@ -371,7 +450,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new Federated Analytics with Pandas example** ([#1469](https://github.com/adap/flower/pull/1469), [#1535](https://github.com/adap/flower/pull/1535)) - A new code example (`quickstart_pandas`) demonstrates federated analytics with Pandas and Flower. You can find it here: [quickstart_pandas](https://github.com/adap/flower/tree/main/examples/quickstart_pandas). + A new code example (`quickstart-pandas`) demonstrates federated analytics with Pandas and Flower. You can find it here: [quickstart-pandas](https://github.com/adap/flower/tree/main/examples/quickstart-pandas). - **Add new strategies: Krum and MultiKrum** ([#1481](https://github.com/adap/flower/pull/1481)) @@ -389,7 +468,7 @@ We would like to give our special thanks to all the contributors who made the ne As usual, the documentation has improved quite a bit. It is another step in our effort to make the Flower documentation the best documentation of any project. Stay tuned and as always, feel free to provide feedback!
- One highlight is the new [first time contributor guide](https://flower.dev/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! + One highlight is the new [first time contributor guide](https://flower.ai/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! ### Incompatible changes @@ -470,7 +549,7 @@ None We would like to give our **special thanks** to all the contributors who made Flower 1.0 possible (in reverse [GitHub Contributors](https://github.com/adap/flower/graphs/contributors) order): -[@rtaiello](https://github.com/rtaiello), [@g-pichler](https://github.com/g-pichler), [@rob-luke](https://github.com/rob-luke), [@andreea-zaharia](https://github.com/andreea-zaharia), [@kinshukdua](https://github.com/kinshukdua), [@nfnt](https://github.com/nfnt), [@tatiana-s](https://github.com/tatiana-s), [@TParcollet](https://github.com/TParcollet), [@vballoli](https://github.com/vballoli), [@negedng](https://github.com/negedng), [@RISHIKESHAVAN](https://github.com/RISHIKESHAVAN), [@hei411](https://github.com/hei411), [@SebastianSpeitel](https://github.com/SebastianSpeitel), [@AmitChaulwar](https://github.com/AmitChaulwar), [@Rubiel1](https://github.com/Rubiel1), [@FANTOME-PAN](https://github.com/FANTOME-PAN), [@Rono-BC](https://github.com/Rono-BC), [@lbhm](https://github.com/lbhm), [@sishtiaq](https://github.com/sishtiaq), [@remde](https://github.com/remde), [@Jueun-Park](https://github.com/Jueun-Park), [@architjen](https://github.com/architjen), [@PratikGarai](https://github.com/PratikGarai), [@mrinaald](https://github.com/mrinaald), [@zliel](https://github.com/zliel), [@MeiruiJiang](https://github.com/MeiruiJiang), [@sandracl72](https://github.com/sandracl72), [@gubertoli](https://github.com/gubertoli), [@Vingt100](https://github.com/Vingt100), [@MakGulati](https://github.com/MakGulati), [@cozek](https://github.com/cozek), [@jafermarq](https://github.com/jafermarq), [@sisco0](https://github.com/sisco0), [@akhilmathurs](https://github.com/akhilmathurs), [@CanTuerk](https://github.com/CanTuerk), [@mariaboerner1987](https://github.com/mariaboerner1987), [@pedropgusmao](https://github.com/pedropgusmao), [@tanertopal](https://github.com/tanertopal), [@danieljanes](https://github.com/danieljanes). 
+[@rtaiello](https://github.com/rtaiello), [@g-pichler](https://github.com/g-pichler), [@rob-luke](https://github.com/rob-luke), [@andreea-zaharia](https://github.com/andreea-zaharia), [@kinshukdua](https://github.com/kinshukdua), [@nfnt](https://github.com/nfnt), [@tatiana-s](https://github.com/tatiana-s), [@TParcollet](https://github.com/TParcollet), [@vballoli](https://github.com/vballoli), [@negedng](https://github.com/negedng), [@RISHIKESHAVAN](https://github.com/RISHIKESHAVAN), [@hei411](https://github.com/hei411), [@SebastianSpeitel](https://github.com/SebastianSpeitel), [@AmitChaulwar](https://github.com/AmitChaulwar), [@Rubiel1](https://github.com/Rubiel1), [@FANTOME-PAN](https://github.com/FANTOME-PAN), [@Rono-BC](https://github.com/Rono-BC), [@lbhm](https://github.com/lbhm), [@sishtiaq](https://github.com/sishtiaq), [@remde](https://github.com/remde), [@Jueun-Park](https://github.com/Jueun-Park), [@architjen](https://github.com/architjen), [@PratikGarai](https://github.com/PratikGarai), [@mrinaald](https://github.com/mrinaald), [@zliel](https://github.com/zliel), [@MeiruiJiang](https://github.com/MeiruiJiang), [@sancarlim](https://github.com/sancarlim), [@gubertoli](https://github.com/gubertoli), [@Vingt100](https://github.com/Vingt100), [@MakGulati](https://github.com/MakGulati), [@cozek](https://github.com/cozek), [@jafermarq](https://github.com/jafermarq), [@sisco0](https://github.com/sisco0), [@akhilmathurs](https://github.com/akhilmathurs), [@CanTuerk](https://github.com/CanTuerk), [@mariaboerner1987](https://github.com/mariaboerner1987), [@pedropgusmao](https://github.com/pedropgusmao), [@tanertopal](https://github.com/tanertopal), [@danieljanes](https://github.com/danieljanes). ### Incompatible changes @@ -578,7 +657,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Flower Baselines (preview): FedOpt, FedBN, FedAvgM** ([#919](https://github.com/adap/flower/pull/919), [#1127](https://github.com/adap/flower/pull/1127), [#914](https://github.com/adap/flower/pull/914)) - The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.dev/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.dev/docs/contributing-baselines.html). + The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/baselines/how-to-contribute-baselines.html). 
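As a minimal sketch of how the FedOpt family is configured (using current API names such as `ndarrays_to_parameters`; the shapes and learning rates below are placeholders, not recommendations):

```python
import numpy as np
import flwr as fl

# Server-side adaptive optimization needs the initial global parameters up front.
# The zero arrays below stand in for a real model's weights.
initial_parameters = fl.common.ndarrays_to_parameters(
    [np.zeros((784, 10)), np.zeros(10)]
)

# FedAdagrad/FedAdam/FedYogi apply a server-side optimizer to the aggregated
# client updates instead of plain averaging.
strategy = fl.server.strategy.FedAdagrad(
    initial_parameters=initial_parameters,
    eta=0.1,     # server-side learning rate
    eta_l=0.01,  # client-side learning rate
)
```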
- **C++ client SDK (preview) and code example** ([#1111](https://github.com/adap/flower/pull/1111)) @@ -624,7 +703,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - New option to keep Ray running if Ray was already initialized in `start_simulation` ([#1177](https://github.com/adap/flower/pull/1177)) - Add support for custom `ClientManager` as a `start_simulation` parameter ([#1171](https://github.com/adap/flower/pull/1171)) - - New documentation for [implementing strategies](https://flower.dev/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) + - New documentation for [implementing strategies](https://flower.ai/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) - New mobile-friendly documentation theme ([#1174](https://github.com/adap/flower/pull/1174)) - Limit version range for (optional) `ray` dependency to include only compatible releases (`>=1.9.2,<1.12.0`) ([#1205](https://github.com/adap/flower/pull/1205)) @@ -744,7 +823,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Renamed q-FedAvg strategy** ([#802](https://github.com/adap/flower/pull/802)) - The strategy named `QffedAvg` was renamed to `QFedAvg` to better reflect the notation given in the original paper (q-FFL is the optimization objective, q-FedAvg is the proposed solver). Note the the original (now deprecated) `QffedAvg` class is still available for compatibility reasons (it will be removed in a future release). + The strategy named `QffedAvg` was renamed to `QFedAvg` to better reflect the notation given in the original paper (q-FFL is the optimization objective, q-FedAvg is the proposed solver). Note that the original (now deprecated) `QffedAvg` class is still available for compatibility reasons (it will be removed in a future release). - **Deprecated and renamed code example** `simulation_pytorch` **to** `simulation_pytorch_legacy` ([#791](https://github.com/adap/flower/pull/791)) @@ -763,9 +842,9 @@ We would like to give our **special thanks** to all the contributors who made Fl The Flower server is now fully task-agnostic; all remaining instances of task-specific metrics (such as `accuracy`) have been replaced by custom metrics dictionaries. Flower 0.15 introduced the capability to pass a dictionary containing custom metrics from client to server. As of this release, custom metrics replace task-specific metrics on the server. - Custom metric dictionaries are now used in two user-facing APIs: they are returned from Strategy methods `aggregate_fit`/`aggregate_evaluate` and they enable evaluation functions passed to build-in strategies (via `eval_fn`) to return more than two evaluation metrics. Strategies can even return *aggregated* metrics dictionaries for the server to keep track of. + Custom metric dictionaries are now used in two user-facing APIs: they are returned from Strategy methods `aggregate_fit`/`aggregate_evaluate` and they enable evaluation functions passed to built-in strategies (via `eval_fn`) to return more than two evaluation metrics. Strategies can even return *aggregated* metrics dictionaries for the server to keep track of.
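In code, the change in return shape looks roughly like this (the numbers are made up; only the shape of the return value matters):

```python
# Before (0.15-style, now deprecated):
#     def eval_fn(weights):
#         return loss, accuracy

# After: a loss plus an arbitrary custom-metrics dictionary
def eval_fn(weights):
    loss, accuracy, f1 = 0.35, 0.91, 0.89  # hypothetical evaluation results
    return loss, {"accuracy": accuracy, "f1": f1}
```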
- Stratey implementations should migrate their `aggregate_fit` and `aggregate_evaluate` methods to the new return type (e.g., by simply returning an empty `{}`), server-side evaluation functions should migrate from `return loss, accuracy` to `return loss, {"accuracy": accuracy}`. + Strategy implementations should migrate their `aggregate_fit` and `aggregate_evaluate` methods to the new return type (e.g., by simply returning an empty `{}`); server-side evaluation functions should migrate from `return loss, accuracy` to `return loss, {"accuracy": accuracy}`. Flower 0.15-style return types are deprecated (but still supported); compatibility will be removed in a future release. @@ -785,7 +864,7 @@ We would like to give our **special thanks** to all the contributors who made Fl The Flower server is now fully serialization-agnostic. Prior usage of class `Weights` (which represents parameters as deserialized NumPy ndarrays) was replaced by class `Parameters` (e.g., in `Strategy`). `Parameters` objects are fully serialization-agnostic and represent parameters as byte arrays; the `tensor_type` attribute indicates how these byte arrays should be interpreted (e.g., for serialization/deserialization). - Built-in strategies implement this approach by handling serialization and deserialization to/from `Weights` internally. Custom/3rd-party Strategy implementations should update to the slighly changed Strategy method definitions. Strategy authors can consult PR [#721](https://github.com/adap/flower/pull/721) to see how strategies can easily migrate to the new format. + Built-in strategies implement this approach by handling serialization and deserialization to/from `Weights` internally. Custom/3rd-party Strategy implementations should update to the slightly changed Strategy method definitions. Strategy authors can consult PR [#721](https://github.com/adap/flower/pull/721) to see how strategies can easily migrate to the new format. - Deprecated `flwr.server.Server.evaluate`, use `flwr.server.Server.evaluate_round` instead ([#717](https://github.com/adap/flower/pull/717)) @@ -806,7 +885,7 @@ What's new? ) model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) - # Create strategy and initilize parameters on the server-side + # Create strategy and initialize parameters on the server-side strategy = fl.server.strategy.FedAvg( # ...
(other constructor arguments) initial_parameters=model.get_weights(), diff --git a/doc/source/ref-example-projects.rst b/doc/source/ref-example-projects.rst index b47bd8e48997..bade86dfaa54 100644 --- a/doc/source/ref-example-projects.rst +++ b/doc/source/ref-example-projects.rst @@ -23,8 +23,8 @@ The TensorFlow/Keras quickstart example shows CIFAR-10 image classification with MobileNetV2: - `Quickstart TensorFlow (Code) `_ -- `Quickstart TensorFlow (Tutorial) `_ -- `Quickstart TensorFlow (Blog Post) `_ +- :doc:`Quickstart TensorFlow (Tutorial) ` +- `Quickstart TensorFlow (Blog Post) `_ Quickstart PyTorch @@ -34,7 +34,7 @@ The PyTorch quickstart example shows CIFAR-10 image classification with a simple Convolutional Neural Network: - `Quickstart PyTorch (Code) `_ -- `Quickstart PyTorch (Tutorial) `_ +- :doc:`Quickstart PyTorch (Tutorial) ` PyTorch: From Centralized To Federated @@ -43,7 +43,7 @@ This example shows how a regular PyTorch project can be federated using Flower: - `PyTorch: From Centralized To Federated (Code) `_ -- `PyTorch: From Centralized To Federated (Tutorial) `_ +- :doc:`PyTorch: From Centralized To Federated (Tutorial) ` Federated Learning on Raspberry Pi and Nvidia Jetson @@ -52,7 +52,7 @@ This example shows how Flower can be used to build a federated learning system that runs across Raspberry Pi and Nvidia Jetson: - `Federated Learning on Raspberry Pi and Nvidia Jetson (Code) `_ -- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ +- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ @@ -60,7 +60,7 @@ Legacy Examples (`flwr_example`) -------------------------------- .. warning:: - The useage examples in `flwr_example` are deprecated and will be removed in + The usage examples in `flwr_example` are deprecated and will be removed in the future. New examples are provided as standalone projects in `examples `_. @@ -114,7 +114,7 @@ For more details, see :code:`src/py/flwr_example/pytorch_cifar`. ImageNet-2012 Image Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -`ImageNet-2012 `_ is one of the major computer +`ImageNet-2012 `_ is one of the major computer vision datasets. The Flower ImageNet example uses PyTorch to train a ResNet-18 classifier in a federated learning setup with ten clients. diff --git a/doc/source/ref-faq.rst b/doc/source/ref-faq.rst index 13c44bc64b0e..26b7dca4a0a7 100644 --- a/doc/source/ref-faq.rst +++ b/doc/source/ref-faq.rst @@ -3,23 +3,23 @@ FAQ This page collects answers to commonly asked questions about Federated Learning with Flower. -.. dropdown:: :fa:`eye,mr-1` Can Flower run on Juptyter Notebooks / Google Colab? +.. dropdown:: :fa:`eye,mr-1` Can Flower run on Jupyter Notebooks / Google Colab? Yes, it can! Flower even comes with a few under-the-hood optimizations to make it work even better on Colab. Here's a quickstart example: - + * `Flower simulation PyTorch `_ * `Flower simulation TensorFlow/Keras `_ .. dropdown:: :fa:`eye,mr-1` How can I run Federated Learning on a Raspberry Pi? - Find the `blog post about federated learning on embedded device here `_ and the corresponding `GitHub code example `_. + Find the `blog post about federated learning on embedded devices here `_ and the corresponding `GitHub code example `_. .. dropdown:: :fa:`eye,mr-1` Does Flower support federated learning on Android devices? - Yes, it does.
Please take a look at our `blog post `_ or check out the code examples: - * `Android Kotlin example `_ - * `Android Java example `_ + * `Android Kotlin example `_ + * `Android Java example `_ .. dropdown:: :fa:`eye,mr-1` Can I combine federated learning with blockchain? @@ -27,6 +27,6 @@ This page collects answers to commonly asked questions about Federated Learning * `Flower meets Nevermined GitHub Repository `_. * `Flower meets Nevermined YouTube video `_. - * `Flower meets KOSMoS `_. + * `Flower meets KOSMoS `_. * `Flower meets Talan blog post `_ . * `Flower meets Talan GitHub Repository `_ . diff --git a/doc/source/ref-telemetry.md b/doc/source/ref-telemetry.md index 206e641d8b41..49efef5c8559 100644 --- a/doc/source/ref-telemetry.md +++ b/doc/source/ref-telemetry.md @@ -41,7 +41,7 @@ Flower telemetry collects the following metrics: **Source.** Flower telemetry tries to store a random source ID in `~/.flwr/source` the first time a telemetry event is generated. The source ID is important to identify whether an issue is recurring or whether an issue is triggered by multiple clusters running concurrently (which often happens in simulation). For example, if a device runs multiple workloads at the same time, and this results in an issue, then, in order to reproduce the issue, multiple workloads must be started at the same time. -You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.dev`. All events related to that source ID will then be permanently deleted. +You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.ai`. All events related to that source ID will then be permanently deleted. We will not collect any personally identifiable information. If you think any of the metrics collected could be misused in any way, please [get in touch with us](#how-to-contact-us). We will update this page to reflect any changes to the metrics collected and publish changes in the changelog. @@ -63,4 +63,4 @@ FLWR_TELEMETRY_ENABLED=0 FLWR_TELEMETRY_LOGGING=1 python server.py # or client.p ## How to contact us -We want to hear from you. If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.dev/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.dev`). +We want to hear from you. If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.ai/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.ai`). diff --git a/doc/source/tutorial-quickstart-huggingface.rst b/doc/source/tutorial-quickstart-huggingface.rst index 7718e6558456..7d8128230901 100644 --- a/doc/source/tutorial-quickstart-huggingface.rst +++ b/doc/source/tutorial-quickstart-huggingface.rst @@ -9,8 +9,8 @@ Quickstart 🤗 Transformers Let's build a federated learning system using Hugging Face Transformers and Flower! -We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. -More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) +We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. 
+More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) for sequence classification over a dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative. @@ -32,8 +32,8 @@ Standard Hugging Face workflow Handling the data ^^^^^^^^^^^^^^^^^ -To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. -We then need to tokenize the data and create :code:`PyTorch` dataloaders, +To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. +We then need to tokenize the data and create :code:`PyTorch` dataloaders; this is all done in the :code:`load_data` function: .. code-block:: python @@ -80,8 +80,8 @@ this is all done in the :code:`load_data` function: Training and testing the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once we have a way of creating our trainloader and testloader, -we can take care of the training and testing. +Once we have a way of creating our trainloader and testloader, +we can take care of the training and testing. This is very similar to any :code:`PyTorch` training or testing loop: .. code-block:: python @@ -120,12 +120,12 @@ This is very similar to any :code:`PyTorch` training or testing loop: Creating the model itself ^^^^^^^^^^^^^^^^^^^^^^^^^ -To create the model itself, +To create the model itself, we will just load the pre-trained distilBERT model using Hugging Face’s :code:`AutoModelForSequenceClassification` : .. code-block:: python - from transformers import AutoModelForSequenceClassification + from transformers import AutoModelForSequenceClassification net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 @@ -138,8 +138,8 @@ Federating the example Creating the IMDBClient ^^^^^^^^^^^^^^^^^^^^^^^ -To federate our example to multiple clients, -we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). +To federate our example to multiple clients, +we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). This is very easy, as our model is a standard :code:`PyTorch` model: .. code-block:: python @@ -166,17 +166,17 @@ This is very easy, as our model is a standard :code:`PyTorch` model: return float(loss), len(testloader), {"accuracy": float(accuracy)} -The :code:`get_parameters` function lets the server get the client's parameters. -Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. -Finally, the :code:`fit` function trains the model locally for the client, -and the :code:`evaluate` function tests the model locally and returns the relevant metrics. +The :code:`get_parameters` function lets the server get the client's parameters. +Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. +Finally, the :code:`fit` function trains the model locally for the client, +and the :code:`evaluate` function tests the model locally and returns the relevant metrics. Starting the server ^^^^^^^^^^^^^^^^^^^ -Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results.
+Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`, +which will define the global weights as the average of all the clients' weights at each round) and then using the :code:`flwr.server.start_server` function: .. code-block:: python @@ -186,7 +186,7 @@ and then using the :code:`flwr.server.start_server` function: losses = [num_examples * m["loss"] for num_examples, m in metrics] examples = [num_examples for num_examples, _ in metrics] return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} - + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=1.0, @@ -202,7 +202,7 @@ and then using the :code:`flwr.server.start_server` function: ) -The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst +The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst the clients (basically this allows us to display a nice average accuracy and loss for every round). Putting everything together @@ -212,19 +212,18 @@ We can now start client instances using: .. code-block:: python - fl.client.start_numpy_client( - server_address="127.0.0.1:8080", - client=IMDBClient() + fl.client.start_client( + server_address="127.0.0.1:8080", + client=IMDBClient().to_client() ) And they will be able to connect to the server and start the federated training. -If you want to check out everything put together, -you should check out the full code example: -[https://github.com/adap/flower/tree/main/examples/quickstart-huggingface](https://github.com/adap/flower/tree/main/examples/quickstart-huggingface). +If you want to see everything put together, +you should check out the `full code example `_. -Of course, this is a very basic example, and a lot can be added or modified, +Of course, this is a very basic example, and a lot can be added or modified; it was just to showcase how simply we could federate a Hugging Face workflow using Flower. Note that in this example we used :code:`PyTorch`, but we could have very well used :code:`TensorFlow`. diff --git a/doc/source/tutorial-quickstart-ios.rst b/doc/source/tutorial-quickstart-ios.rst index 7c8007baaa75..e4315ce569fb 100644 --- a/doc/source/tutorial-quickstart-ios.rst +++ b/doc/source/tutorial-quickstart-ios.rst @@ -7,14 +7,14 @@ Quickstart iOS .. meta:: :description: Read this Federated Learning quickstart tutorial for creating an iOS app using Flower to train a neural network on MNIST. -In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. +In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. -First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. For the Flower client implementation in iOS, it is recommended to use Xcode as our IDE. -Our example consists of one Python *server* and two iPhone *clients* that all have the same model. +Our example consists of one Python *server* and two iPhone *clients* that all have the same model. -*Clients* are responsible for generating individual weight updates for the model based on their local datasets.
+*Clients* are responsible for generating individual weight updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -44,10 +44,10 @@ For simplicity reasons we will use the complete Flower client with CoreML, that public func getParameters() -> GetParametersRes { let parameters = parameters.weightsToParameters() let status = Status(code: .ok, message: String()) - + return GetParametersRes(parameters: parameters, status: status) } - + /// Calls the routine to fit the local model /// /// - Returns: The result from the local training, e.g., updated parameters @@ -55,17 +55,17 @@ For simplicity reasons we will use the complete Flower client with CoreML, that let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .train) let parameters = parameters.weightsToParameters() - + return FitRes(parameters: parameters, numExamples: result.numSamples, status: status) } - + /// Calls the routine to evaluate the local model /// /// - Returns: The result from the evaluation, e.g., loss public func evaluate(ins: EvaluateIns) -> EvaluateRes { let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .test) - + return EvaluateRes(loss: Float(result.loss), numExamples: result.numSamples, status: status) } @@ -88,18 +88,18 @@ For the MNIST dataset, we need to preprocess it into :code:`MLBatchProvider` obj // prepare train dataset let trainBatchProvider = DataLoader.trainBatchProvider() { _ in } - + // prepare test dataset let testBatchProvider = DataLoader.testBatchProvider() { _ in } - + // load them together - let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, + let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, testBatchProvider: testBatchProvider) Since CoreML does not allow the model parameters to be seen before training, and accessing the model parameters during or after the training can only be done by specifying the layer name, -we need to know this informations beforehand, through looking at the model specification, which are written as proto files. The implementation can be seen in :code:`MLModelInspect`. +we need to know this information beforehand, by looking at the model specification, which is written in proto files. The implementation can be seen in :code:`MLModelInspect`. -After we have all of the necessary informations, let's create our Flower client. +After we have all of the necessary information, let's create our Flower client. .. code-block:: swift @@ -122,7 +122,7 @@ Then start the Flower gRPC client and start communicating to the server by passi self.flwrGRPC.startFlwrGRPC(client: self.mlFlwrClient) That's it for the client. We only have to implement :code:`Client` or call the provided -:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attribute :code:`hostname` and :code:`port` tells the client which server to connect to. +:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attributes :code:`hostname` and :code:`port` tell the client which server to connect to. This can be done by entering the hostname and port in the application before clicking the start button to start the federated learning process.
Flower Server diff --git a/doc/source/tutorial-quickstart-jax.rst b/doc/source/tutorial-quickstart-jax.rst index 945f231e112e..d2b9243e2bb3 100644 --- a/doc/source/tutorial-quickstart-jax.rst +++ b/doc/source/tutorial-quickstart-jax.rst @@ -265,7 +265,7 @@ Having defined the federation process, we can run it. # Start Flower client client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y) - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client) + fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) if __name__ == "__main__": main() diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst index ff8d4b2087dd..fe582f793280 100644 --- a/doc/source/tutorial-quickstart-mxnet.rst +++ b/doc/source/tutorial-quickstart-mxnet.rst @@ -4,18 +4,18 @@ Quickstart MXNet ================ -.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongise Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. +.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongside Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST. -In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet. +In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet. -It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. +*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*. @@ -35,12 +35,12 @@ Since we want to use MXNet, let's go ahead and install it: Flower Client ------------- -Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet´s `Hand-written Digit Recognition tutorial `_. +Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet's `Hand-written Digit Recognition tutorial `_. In a file called :code:`client.py`, import Flower and MXNet related packages: .. code-block:: python - import flwr as fl import numpy as np @@ -58,7 +58,7 @@ In addition, define the device allocation in MXNet with: DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()] -We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning.
The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data. +We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data. .. code-block:: python @@ -72,7 +72,7 @@ We use MXNet to load MNIST, a popular image classification dataset of handwritte val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size) return train_data, val_data -Define the training and loss with MXNet. We train the model by looping over the dataset, measure the corresponding loss, and optimize it. +Define the training and loss with MXNet. We train the model by looping over the dataset, measuring the corresponding loss, and optimizing it. .. code-block:: python @@ -110,7 +110,7 @@ Define the training and loss with MXNet. We train the model by looping over the return trainings_metric, num_examples -Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy on the test set. +Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy on the test set. .. code-block:: python @@ -155,7 +155,7 @@ Our Flower clients will use a simple :code:`Sequential` model: init = nd.random.uniform(shape=(2, 784)) model(init) -After loading the dataset with :code:`load_data()` we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client. +After loading the dataset with :code:`load_data()` we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client. The Flower server interacts with clients through an interface called :code:`Client`. When the server selects a particular client for training, it @@ -207,7 +207,7 @@ They can be implemented in the following way: [accuracy, loss], num_examples = test(model, val_data) print("Evaluation accuracy & loss", accuracy, loss) return float(loss[1]), val_data.batch_size, {"accuracy": float(accuracy[1])} - We can now create an instance of our class :code:`MNISTClient` and add one line to actually run this client: diff --git a/doc/source/tutorial-quickstart-pytorch.rst b/doc/source/tutorial-quickstart-pytorch.rst index fb77d107b63f..895590808a2b 100644 --- a/doc/source/tutorial-quickstart-pytorch.rst +++ b/doc/source/tutorial-quickstart-pytorch.rst @@ -10,13 +10,13 @@ Quickstart PyTorch .. youtube:: jOmmuzMIQ4c :width: 100% -In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch. +In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch. -First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model.
Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -26,7 +26,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library: +Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library: .. code-block:: shell @@ -36,12 +36,12 @@ Since we want to use PyTorch to solve a computer vision task, let's go ahead and Flower Client ------------- -Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_. +Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_. In a file called :code:`client.py`, import Flower and PyTorch related packages: .. code-block:: python - from collections import OrderedDict import torch @@ -59,7 +59,7 @@ In addition, we define the device allocation in PyTorch with: DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized. +We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized. .. code-block:: python @@ -75,7 +75,7 @@ We use PyTorch to load CIFAR10, a popular colored image classification dataset f num_examples = {"trainset" : len(trainset), "testset" : len(testset)} return trainloader, testloader, num_examples -Define the loss and optimizer with PyTorch. The training of the dataset is done by looping over the dataset, measure the corresponding loss and optimize it. +Define the loss and optimizer with PyTorch. Training is done by looping over the dataset, measuring the corresponding loss, and optimizing it. .. code-block:: python @@ -91,7 +91,7 @@ Define the loss and optimizer with PyTorch. The training of the dataset is done loss.backward() optimizer.step() -Define then the validation of the machine learning network. We loop over the test set and measure the loss and accuracy of the test set. +Then define the validation of the machine learning network. We loop over the test set and measure the loss and accuracy of the test set. .. code-block:: python @@ -139,7 +139,7 @@ The Flower clients will use a simple CNN adapted from 'PyTorch: A 60 Minute Blit net = Net().to(DEVICE) trainloader, testloader, num_examples = load_data() -After loading the data set with :code:`load_data()` we define the Flower interface. +After loading the data set with :code:`load_data()` we define the Flower interface. The Flower server interacts with clients through an interface called :code:`Client`. When the server selects a particular client for training, it @@ -191,10 +191,10 @@ to actually run this client: ..
code-block:: python - fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient()) + fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"[::]:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we point the client at. diff --git a/doc/source/tutorial-quickstart-scikitlearn.rst b/doc/source/tutorial-quickstart-scikitlearn.rst index b33068e975fa..d1d47dc37f19 100644 --- a/doc/source/tutorial-quickstart-scikitlearn.rst +++ b/doc/source/tutorial-quickstart-scikitlearn.rst @@ -7,13 +7,13 @@ Quickstart scikit-learn .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with scikit-learn to train a logistic regression model. -In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn. +In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn. -It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. +*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*. @@ -23,7 +23,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use scikt-learn, let's go ahead and install it: +Since we want to use scikit-learn, let's go ahead and install it: .. code-block:: shell @@ -43,7 +43,7 @@ Now that we have all our dependencies installed, let's run a simple distributed However, before setting up the client and server, we will define all functionalities that we need for our federated learning setup within :code:`utils.py`.
The :code:`utils.py` contains different functions defining all the machine learning basics: * :code:`get_model_parameters()` - * Returns the paramters of a :code:`sklearn` LogisticRegression model + * Returns the parameters of a :code:`sklearn` LogisticRegression model * :code:`set_model_params()` * Sets the parameters of a :code:`sklearn` LogisticRegression model * :code:`set_initial_params()` @@ -59,7 +59,7 @@ Please check out :code:`utils.py` `here `_ -We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. +We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. .. code-block:: python @@ -145,10 +145,10 @@ to actually run this client: .. code-block:: python - fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient()) + fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"0.0.0.0:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"0.0.0.0:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"0.0.0.0:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we pass to the client. diff --git a/doc/source/tutorial-quickstart-tensorflow.rst b/doc/source/tutorial-quickstart-tensorflow.rst index 64b2255a9ac6..bd63eb461d21 100644 --- a/doc/source/tutorial-quickstart-tensorflow.rst +++ b/doc/source/tutorial-quickstart-tensorflow.rst @@ -84,11 +84,11 @@ to actually run this client: .. code-block:: python - fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient()) + fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"[::]:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we point the client at.
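For anyone updating an older copy of these quickstarts by hand, the migration is mechanical. A sketch using the :code:`CifarClient` defined above (the commented-out line shows the deprecated call):

.. code-block:: python

    # Deprecated: passing a NumPyClient straight to start_numpy_client
    # fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())

    # Current: convert the NumPyClient first, then use start_client
    fl.client.start_client(
        server_address="[::]:8080",
        client=CifarClient().to_client(),
    )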
diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst index 3a7b356c4d2a..751024db14e4 100644 --- a/doc/source/tutorial-quickstart-xgboost.rst +++ b/doc/source/tutorial-quickstart-xgboost.rst @@ -36,7 +36,7 @@ and then we dive into a more complex example (`full code xgboost-comprehensive < Environment Setup -------------------- -First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. We first need to install Flower and Flower Datasets. You can do this by running: .. code-block:: shell @@ -596,7 +596,7 @@ Comprehensive Federated XGBoost Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings. In the xgboost-comprehensive example (`full code `_), we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation. -We also support `Flower simulation `_ making it easy to simulate large client cohorts in a resource-aware manner. +We also support :doc:`Flower simulation `, making it easy to simulate large client cohorts in a resource-aware manner. Let's take a look! Cyclic training diff --git a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb index 5b2236468909..c5fc777e7f26 100644 --- a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb +++ b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Build a strategy from scratch\n", "\n", - "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", + "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", "\n", - "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n", + "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! 
And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's build a new `Strategy` from scratch!" ] @@ -489,11 +489,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 4](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." + "The [Flower Federated Learning Tutorial - Part 4](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." ] } ], diff --git a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb index 0ff67de6f51d..dbdd1094173c 100644 --- a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb +++ b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Customize the client\n", "\n", - "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", + "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", "\n", "In this notebook, we revisit `NumPyClient` and introduce a new base class for building clients, simply named `Client`. In previous parts of this tutorial, we've based our client on `NumPyClient`, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. 
With `Client`, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things that we didn't have to do before.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's go deeper and see what it takes to move from `NumPyClient` to `Client`!" ] @@ -500,7 +500,7 @@ " # We convert our ndarray into a sparse matrix\n", " ndarray = torch.tensor(ndarray).to_sparse_csr()\n", "\n", - " # And send it by utilizng the sparse matrix attributes\n", + " # And send it by utilizing the sparse matrix attributes\n", " # WARNING: NEVER set allow_pickle to true.\n", " # Reason: loading pickled data can execute arbitrary code\n", " # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n", @@ -550,7 +550,7 @@ "source": [ "### Client-side\n", "\n", - "To be able to able to serialize our `ndarray`s into sparse parameters, we will just have to call our custom functions in our `flwr.client.Client`.\n", + "To be able to serialize our `ndarray`s into sparse parameters, we will just have to call our custom functions in our `flwr.client.Client`.\n", "\n", "Indeed, in `get_parameters` we need to serialize the parameters we got from our network using our custom `ndarrays_to_sparse_parameters` defined above.\n", "\n", @@ -813,7 +813,7 @@ " for _, fit_res in results\n", " ]\n", "\n", - " # We serialize the aggregated result using our cutom method\n", + " # We serialize the aggregated result using our custom method\n", " parameters_aggregated = ndarrays_to_sparse_parameters(\n", " aggregate(weights_results)\n", " )\n", @@ -869,16 +869,16 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", "This is the final part of the Flower tutorial (for now!). Congratulations! You're now well equipped to understand the rest of the documentation. 
There are many topics we didn't cover in the tutorial; we recommend the following resources:\n", "\n", - "- [Read Flower Docs](https://flower.dev/docs/)\n", + "- [Read Flower Docs](https://flower.ai/docs/)\n", "- [Check out Flower Code Examples](https://github.com/adap/flower/tree/main/examples)\n", - "- [Use Flower Baselines for your research](https://flower.dev/docs/baselines/)\n", - "- [Watch Flower Summit 2023 videos](https://flower.dev/conf/flower-summit-2023/)\n" + "- [Use Flower Baselines for your research](https://flower.ai/docs/baselines/)\n", + "- [Watch Flower Summit 2023 videos](https://flower.ai/conf/flower-summit-2023/)\n" ] } ], diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb index bbd916b32375..fab3dafba5e5 100644 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -9,9 +9,9 @@ "\n", "Welcome to the Flower federated learning tutorial!\n", "\n", - "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.dev/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", + "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.ai/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get started!" ] @@ -83,7 +83,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware acclerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`, otherwise it'll say `Training on cpu`." + "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`, otherwise it'll say `Training on cpu`." 
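For context, the `DEVICE` constant this cell refers to is presumably defined earlier in the notebook along these lines (a minimal sketch consistent with the behaviour described above):

```python
import torch

# Prefer CUDA when the runtime provides it, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
```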
] }, { @@ -368,14 +368,14 @@ "metadata": {}, "outputs": [], "source": [ - "def get_parameters(net) -> List[np.ndarray]:\n", - " return [val.cpu().numpy() for _, val in net.state_dict().items()]\n", - "\n", - "\n", "def set_parameters(net, parameters: List[np.ndarray]):\n", " params_dict = zip(net.state_dict().keys(), parameters)\n", " state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n", - " net.load_state_dict(state_dict, strict=True)" + " net.load_state_dict(state_dict, strict=True)\n", + "\n", + "\n", + "def get_parameters(net) -> List[np.ndarray]:\n", + " return [val.cpu().numpy() for _, val in net.state_dict().items()]" ] }, { @@ -485,10 +485,10 @@ ")\n", "\n", "# Specify the resources each of your clients needs. By default, each\n", - "# client will be allocated 1x CPU and 0x CPUs\n", + "# client will be allocated 1x CPU and 0x GPUs\n", "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "if DEVICE.type == \"cuda\":\n", - " # here we are asigning an entire GPU for each client.\n", + " # here we are assigning an entire GPU for each client.\n", " client_resources = {\"num_cpus\": 1, \"num_gpus\": 1.0}\n", " # Refer to our documentation for more details about Flower Simulations\n", " # and how to set up these `client_resources`.\n", @@ -605,11 +605,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" + "The [Flower Federated Learning Tutorial - Part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" ] } ], diff --git a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb index 06f53cd8e1b1..e20a8d83f674 100644 --- a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb +++ b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Use a federated learning strategy\n", "\n", - "Welcome to the next part of the federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n", + "Welcome to the next part of the federated learning tutorial. 
In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n", "\n", - "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n", + "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's move beyond FedAvg with Flower strategies!" ] @@ -614,11 +614,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." + "The [Flower Federated Learning Tutorial - Part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." ] } ], diff --git a/doc/source/tutorial-series-what-is-federated-learning.ipynb b/doc/source/tutorial-series-what-is-federated-learning.ipynb index 3f7e383b9fbc..7d77d1770457 100755 --- a/doc/source/tutorial-series-what-is-federated-learning.ipynb +++ b/doc/source/tutorial-series-what-is-federated-learning.ipynb @@ -13,7 +13,7 @@ "\n", "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get started!" 
] @@ -98,7 +98,7 @@ "- Location data from your electric car to make better range prediction\n", "- End-to-end encrypted messages to train better auto-complete models\n", "\n", - "The popularity of privacy-enhancing systems like the [Brave](https://brave.com/) browser or the [Signal](https://signal.org/) messenger shows that users care about privacy. In fact, they choose the privacy-enhancing version over other alternatives, if such an alernative exists. But what can we do to apply machine learning and data science to these cases to utilize private data? After all, these are all areas that would benefit significantly from recent advances in AI." + "The popularity of privacy-enhancing systems like the [Brave](https://brave.com/) browser or the [Signal](https://signal.org/) messenger shows that users care about privacy. In fact, they choose the privacy-enhancing version over other alternatives, if such an alternative exists. But what can we do to apply machine learning and data science to these cases to utilize private data? After all, these are all areas that would benefit significantly from recent advances in AI." ] }, { @@ -217,11 +217,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." + "The [Flower Federated Learning Tutorial - Part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." 
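As a small preview of what the strategy-focused tutorial parts referenced in these notebooks cover, a server configured with the built-in `FedAvg` strategy might look like the following sketch (all hyperparameter values here are illustrative assumptions, not taken from the tutorials themselves):

```python
import flwr as fl

# Illustrative FedAvg configuration; every value is an example choice
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,         # sample 50% of connected clients for training
    fraction_evaluate=0.5,    # sample 50% of connected clients for evaluation
    min_available_clients=2,  # wait until at least two clients are connected
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
```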
] } ], diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 20a5b4875ddf..b4570b36512d 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -25,15 +25,15 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), insecure=False, ) diff --git a/e2e/bare-https/driver.py b/e2e/bare-https/driver.py index 5c44e4c641ae..f7bfeb613f6a 100644 --- a/e2e/bare-https/driver.py +++ b/e2e/bare-https/driver.py @@ -3,7 +3,7 @@ # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="127.0.0.1:9091", config=fl.server.ServerConfig(num_rounds=3), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare-https/pyproject.toml b/e2e/bare-https/pyproject.toml index 9489a43195f9..3afb7b57a084 100644 --- a/e2e/bare-https/pyproject.toml +++ b/e2e/bare-https/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "bare_https_test" version = "0.1.0" description = "HTTPS-enabled bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 8e5c3adff5e6..c291fb0963e4 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -3,6 +3,8 @@ import flwr as fl import numpy as np +from flwr.common import ConfigsRecord + SUBSET_SIZE = 1000 STATE_VAR = 'timestamp' @@ -18,13 +20,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): model_params = parameters @@ -42,10 +46,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/bare/driver.py b/e2e/bare/driver.py index 6bd61e344ad1..defc2ad56213 100644 --- a/e2e/bare/driver.py +++ b/e2e/bare/driver.py @@ -2,7 +2,7 @@ # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/bare/pyproject.toml b/e2e/bare/pyproject.toml index b9a4028806c3..45ce7ea333af 100644 --- a/e2e/bare/pyproject.toml +++ b/e2e/bare/pyproject.toml @@ -6,8 +6,8 @@ build-backend = 
"poetry.core.masonry.api" name = "bare_test" version = "0.1.0" description = "Bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" -flwr = { path = "../../", develop = true, extras = ["simulation"] } +flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } diff --git a/e2e/fastai/client.py b/e2e/fastai/client.py index 4425fed25277..c4bfb89c2dde 100644 --- a/e2e/fastai/client.py +++ b/e2e/fastai/client.py @@ -53,14 +53,14 @@ def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/fastai/driver.py b/e2e/fastai/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/fastai/driver.py +++ b/e2e/fastai/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/fastai/pyproject.toml b/e2e/fastai/pyproject.toml index 66fcb393d988..feed31f6d202 100644 --- a/e2e/fastai/pyproject.toml +++ b/e2e/fastai/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-fastai" version = "0.1.0" description = "Fastai Federated Learning E2E test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.10" diff --git a/e2e/jax/client.py b/e2e/jax/client.py index 495d6a671981..a4e4d1f55117 100644 --- a/e2e/jax/client.py +++ b/e2e/jax/client.py @@ -53,10 +53,10 @@ def evaluate( def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/jax/driver.py b/e2e/jax/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/jax/driver.py +++ b/e2e/jax/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/jax/pyproject.toml b/e2e/jax/pyproject.toml index 3db32ea855eb..9a4af5dee59a 100644 --- a/e2e/jax/pyproject.toml +++ b/e2e/jax/pyproject.toml @@ -2,7 +2,7 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/mxnet/.gitignore b/e2e/mxnet/.gitignore deleted file mode 100644 index 10d00b5797e2..000000000000 --- a/e2e/mxnet/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.gz diff --git a/e2e/mxnet/README.md b/e2e/mxnet/README.md deleted file mode 100644 index 3fa76bac5ce0..000000000000 --- a/e2e/mxnet/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Flower with MXNet testing - -This directory is used for testing Flower with MXNet by using a simple NN with MNIST data. - -It uses the `FedAvg` strategy. 
\ No newline at end of file diff --git a/e2e/mxnet/client.py b/e2e/mxnet/client.py deleted file mode 100644 index 2f0b714e708c..000000000000 --- a/e2e/mxnet/client.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Flower client example using MXNet for MNIST classification. - -The code is generally adapted from: - -https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html -""" - - -import flwr as fl -import numpy as np -import mxnet as mx -from mxnet import nd -from mxnet import gluon -from mxnet.gluon import nn -from mxnet import autograd as ag -import mxnet.ndarray as F - -SUBSET_SIZE = 50 - -# Fixing the random seed -mx.random.seed(42) - -# Setup context to GPU or CPU -DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()] - -def model(): - net = nn.Sequential() - net.add(nn.Dense(256, activation="relu")) - net.add(nn.Dense(64, activation="relu")) - net.add(nn.Dense(10)) - net.collect_params().initialize() - return net - -def load_data(): - print("Download Dataset") - mnist = mx.test_utils.get_mnist() - batch_size = 100 - train_data = mx.io.NDArrayIter( - mnist["train_data"][:SUBSET_SIZE], mnist["train_label"][:SUBSET_SIZE], batch_size, shuffle=True - ) - val_data = mx.io.NDArrayIter(mnist["test_data"][:10], mnist["test_label"][:10], batch_size) - return train_data, val_data - - -def train(net, train_data, epoch): - trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01}) - accuracy_metric = mx.metric.Accuracy() - loss_metric = mx.metric.CrossEntropy() - metrics = mx.metric.CompositeEvalMetric() - for child_metric in [accuracy_metric, loss_metric]: - metrics.add(child_metric) - softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss() - for i in range(epoch): - train_data.reset() - num_examples = 0 - for batch in train_data: - data = gluon.utils.split_and_load( - batch.data[0], ctx_list=DEVICE, batch_axis=0 - ) - label = gluon.utils.split_and_load( - batch.label[0], ctx_list=DEVICE, batch_axis=0 - ) - outputs = [] - with ag.record(): - for x, y in zip(data, label): - z = net(x) - loss = softmax_cross_entropy_loss(z, y) - loss.backward() - outputs.append(z.softmax()) - num_examples += len(x) - metrics.update(label, outputs) - trainer.step(batch.data[0].shape[0]) - trainings_metric = metrics.get_name_value() - print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric)) - return trainings_metric, num_examples - - -def test(net, val_data): - accuracy_metric = mx.metric.Accuracy() - loss_metric = mx.metric.CrossEntropy() - metrics = mx.metric.CompositeEvalMetric() - for child_metric in [accuracy_metric, loss_metric]: - metrics.add(child_metric) - val_data.reset() - num_examples = 0 - for batch in val_data: - data = gluon.utils.split_and_load(batch.data[0], ctx_list=DEVICE, batch_axis=0) - label = gluon.utils.split_and_load( - batch.label[0], ctx_list=DEVICE, batch_axis=0 - ) - outputs = [] - for x in data: - outputs.append(net(x).softmax()) - num_examples += len(x) - metrics.update(label, outputs) - metrics.update(label, outputs) - return metrics.get_name_value(), num_examples - -train_data, val_data = load_data() - -model = model() -init = nd.random.uniform(shape=(2, 784)) -model(init) - -# Flower Client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - param = [] - for val in model.collect_params(".*weight").values(): - p = val.data() - param.append(p.asnumpy()) - return param - - def set_parameters(self, parameters): - params = zip(model.collect_params(".*weight").keys(), parameters) - for key, value in 
params: - model.collect_params().setattr(key, value) - - def fit(self, parameters, config): - self.set_parameters(parameters) - [accuracy, loss], num_examples = train(model, train_data, epoch=2) - results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])} - return self.get_parameters(config={}), num_examples, results - - def evaluate(self, parameters, config): - self.set_parameters(parameters) - [accuracy, loss], num_examples = test(model, val_data) - print("Evaluation accuracy & loss", accuracy, loss) - return float(loss[1]), num_examples, {"accuracy": float(accuracy[1])} - - -def client_fn(cid): - return FlowerClient().to_client() - -flower = fl.flower.Flower( - client_fn=client_fn, -) - -if __name__ == "__main__": - # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) diff --git a/e2e/mxnet/driver.py b/e2e/mxnet/driver.py deleted file mode 100644 index 2b1b35d9e89c..000000000000 --- a/e2e/mxnet/driver.py +++ /dev/null @@ -1,7 +0,0 @@ -import flwr as fl - -hist = fl.driver.start_driver( - server_address="0.0.0.0:9091", - config=fl.server.ServerConfig(num_rounds=3), -) -assert (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 1 diff --git a/e2e/mxnet/pyproject.toml b/e2e/mxnet/pyproject.toml deleted file mode 100644 index 71bd0e6374bd..000000000000 --- a/e2e/mxnet/pyproject.toml +++ /dev/null @@ -1,15 +0,0 @@ -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "mxnet_example" -version = "0.1.0" -description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] - -[tool.poetry.dependencies] -python = "^3.8" -flwr = { path = "../../", develop = true, extras = ["simulation"] } -mxnet = "^1.7.0" -numpy = "1.23.1" diff --git a/e2e/mxnet/simulation.py b/e2e/mxnet/simulation.py deleted file mode 100644 index 5f0e5334bd08..000000000000 --- a/e2e/mxnet/simulation.py +++ /dev/null @@ -1,11 +0,0 @@ -import flwr as fl - -from client import client_fn - -hist = fl.simulation.start_simulation( - client_fn=client_fn, - num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), -) - -assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 diff --git a/e2e/opacus/client.py b/e2e/opacus/client.py index 2e5c363381fa..00437a31233c 100644 --- a/e2e/opacus/client.py +++ b/e2e/opacus/client.py @@ -137,12 +137,12 @@ def client_fn(cid): model = Net() return FlowerClient(model).to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(model) + client=FlowerClient(model).to_client() ) diff --git a/e2e/opacus/driver.py b/e2e/opacus/driver.py index 5a0309914ee0..75acd9ccea24 100644 --- a/e2e/opacus/driver.py +++ b/e2e/opacus/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/opacus/pyproject.toml b/e2e/opacus/pyproject.toml index 1aee7c6ec6d9..ab4a727cc00b 100644 --- a/e2e/opacus/pyproject.toml +++ b/e2e/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "opacus_e2e" version = "0.1.0" description = "Opacus E2E testing" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff 
--git a/e2e/pandas/client.py b/e2e/pandas/client.py index 5b8670091cb3..0ecd75df3ae8 100644 --- a/e2e/pandas/client.py +++ b/e2e/pandas/client.py @@ -1,4 +1,3 @@ -import warnings from typing import Dict, List, Tuple import numpy as np @@ -36,13 +35,13 @@ def fit( def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/pandas/driver.py b/e2e/pandas/driver.py index b33e1e54f4a0..f5dc74c9f3f8 100644 --- a/e2e/pandas/driver.py +++ b/e2e/pandas/driver.py @@ -3,7 +3,7 @@ from strategy import FedAnalytics # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=1), strategy=FedAnalytics(), diff --git a/e2e/pandas/pyproject.toml b/e2e/pandas/pyproject.toml index b90037a31068..416dfeec3460 100644 --- a/e2e/pandas/pyproject.toml +++ b/e2e/pandas/pyproject.toml @@ -7,7 +7,7 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pandas/simulation.py b/e2e/pandas/simulation.py index 91af84062712..b548b5ebb760 100644 --- a/e2e/pandas/simulation.py +++ b/e2e/pandas/simulation.py @@ -1,12 +1,8 @@ import flwr as fl -from client import FlowerClient +from client import client_fn from strategy import FedAnalytics -def client_fn(cid): - _ = cid - return FlowerClient() - hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, diff --git a/e2e/pytorch-lightning/client.py b/e2e/pytorch-lightning/client.py index 71b178eca8c3..fde550e31c08 100644 --- a/e2e/pytorch-lightning/client.py +++ b/e2e/pytorch-lightning/client.py @@ -55,7 +55,7 @@ def client_fn(cid): # Flower client return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) @@ -65,8 +65,8 @@ def main() -> None: train_loader, val_loader, test_loader = mnist.load_data() # Flower client - client = FlowerClient(model, train_loader, val_loader, test_loader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/pytorch-lightning/driver.py b/e2e/pytorch-lightning/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/pytorch-lightning/driver.py +++ b/e2e/pytorch-lightning/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/pytorch-lightning/pyproject.toml b/e2e/pytorch-lightning/pyproject.toml index e79eb72a56df..88cddddf500f 100644 --- a/e2e/pytorch-lightning/pyproject.toml +++ b/e2e/pytorch-lightning/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" description = "Federated Learning E2E test with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower 
Authors "] [tool.poetry.dependencies] python = "^3.8" flwr = { path = "../../", develop = true, extras = ["simulation"] } -pytorch-lightning = "1.6.0" +pytorch-lightning = "2.1.3" torchvision = "0.14.1" diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index d180ad5d4eca..1fd07763148e 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -11,6 +11,7 @@ from tqdm import tqdm import flwr as fl +from flwr.common import ConfigsRecord # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -95,14 +96,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] - + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) @@ -124,14 +126,14 @@ def set_parameters(model, parameters): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/pytorch/driver.py b/e2e/pytorch/driver.py index ca860ea47b2d..2ea4de69a62b 100644 --- a/e2e/pytorch/driver.py +++ b/e2e/pytorch/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/pytorch/pyproject.toml b/e2e/pytorch/pyproject.toml index 4aa326330c62..e538f1437df6 100644 --- a/e2e/pytorch/pyproject.toml +++ b/e2e/pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/scikit-learn/client.py b/e2e/scikit-learn/client.py index fdca96c1697a..e073d3cb2748 100644 --- a/e2e/scikit-learn/client.py +++ b/e2e/scikit-learn/client.py @@ -46,10 +46,10 @@ def evaluate(self, parameters, config): # type: ignore def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=FlowerClient()) + fl.client.start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/scikit-learn/driver.py b/e2e/scikit-learn/driver.py index 032a2f7a0dc6..29051d02c6b6 100644 --- a/e2e/scikit-learn/driver.py 
+++ b/e2e/scikit-learn/driver.py @@ -36,7 +36,7 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config): evaluate_fn=get_evaluate_fn(model), on_fit_config_fn=fit_round, ) - hist = fl.driver.start_driver( + hist = fl.server.start_driver( server_address="0.0.0.0:9091", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3), diff --git a/e2e/scikit-learn/pyproject.toml b/e2e/scikit-learn/pyproject.toml index 372ae1218bbe..50c07d31add7 100644 --- a/e2e/scikit-learn/pyproject.toml +++ b/e2e/scikit-learn/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/e2e/scikit-learn/utils.py b/e2e/scikit-learn/utils.py index d6804fdcce12..2b8dcf8655ee 100644 --- a/e2e/scikit-learn/utils.py +++ b/e2e/scikit-learn/utils.py @@ -10,7 +10,7 @@ def get_model_parameters(model: LogisticRegression) -> LogRegParams: - """Returns the paramters of a sklearn LogisticRegression model.""" + """Returns the parameters of a sklearn LogisticRegression model.""" if model.fit_intercept: params = [ model.coef_, diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index eb4598cb5439..3b49f770dc6b 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -47,11 +47,11 @@ def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/pyproject.toml b/e2e/strategies/pyproject.toml index edfb16de59a6..5cc74b20fa24 100644 --- a/e2e/strategies/pyproject.toml +++ b/e2e/strategies/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart_tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/e2e/tabnet/client.py b/e2e/tabnet/client.py index 3c10df0c79f1..0290ba4629de 100644 --- a/e2e/tabnet/client.py +++ b/e2e/tabnet/client.py @@ -81,10 +81,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/tabnet/driver.py b/e2e/tabnet/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/tabnet/driver.py +++ b/e2e/tabnet/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml index 7b47ffeb1470..b1abf382a24a 100644 --- a/e2e/tabnet/pyproject.toml +++ b/e2e/tabnet/pyproject.toml @@ -6,13 +6,13 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors 
= ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.9.2" tensorflow-io-gcs-filesystem = "<0.35.0" tabnet = "0.1.6" diff --git a/e2e/tensorflow/client.py b/e2e/tensorflow/client.py index 4ad2d5ebda57..10ee91136241 100644 --- a/e2e/tensorflow/client.py +++ b/e2e/tensorflow/client.py @@ -34,10 +34,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/tensorflow/driver.py b/e2e/tensorflow/driver.py index ca860ea47b2d..2ea4de69a62b 100644 --- a/e2e/tensorflow/driver.py +++ b/e2e/tensorflow/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/tensorflow/pyproject.toml b/e2e/tensorflow/pyproject.toml index 467e69a026b3..a7dbfe2305db 100644 --- a/e2e/tensorflow/pyproject.toml +++ b/e2e/tensorflow/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index 32314bd22533..3d4864a1b0fb 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -13,13 +13,34 @@ case "$1" in ;; esac -timeout 2m flower-server $server_arg & +case "$2" in + rest) + rest_arg="--rest" + server_address="http://localhost:9093" + db_arg="--database :flwr-in-memory-state:" + ;; + sqlite) + rest_arg="" + server_address="127.0.0.1:9092" + db_arg="--database $(date +%s).db" + ;; + *) + rest_arg="" + server_address="127.0.0.1:9092" + db_arg="--database :flwr-in-memory-state:" + ;; +esac + +timeout 2m flower-superlink $server_arg $db_arg $rest_arg & +sl_pid=$! sleep 3 -timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & +timeout 2m flower-client-app client:app $client_arg $rest_arg --server $server_address & +cl1_pid=$! sleep 3 -timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & +timeout 2m flower-client-app client:app $client_arg $rest_arg --server $server_address & +cl2_pid=$! sleep 3 timeout 2m python driver.py & @@ -29,7 +50,7 @@ wait $pid res=$? 
if [[ "$res" = "0" ]]; - then echo "Training worked correctly" && pkill flower-client && pkill flower-server; + then echo "Training worked correctly"; kill $cl1_pid; kill $cl2_pid; kill $sl_pid; else echo "Training had an issue" && exit 1; fi diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 9101105b2618..c1ba85b95879 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (PyTorch) -This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set). diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 0eb457d68645..d4c8abe3d404 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -150,9 +150,8 @@ def main() -> None: trainset = trainset.select(range(10)) testset = testset.select(range(10)) # Start Flower client - client = CifarClient(trainset, testset, device, args.model) - - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(trainset, testset, device, args.model).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index 89fd5a32a89e..b846a6054cc8 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -7,8 +7,8 @@ name = "advanced-pytorch" version = "0.1.0" description = "Advanced Flower/PyTorch Example" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 186f079010dc..4a0f6918cdd6 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -9,10 +9,10 @@ warnings.filterwarnings("ignore") -def load_partition(node_id, toy: bool = False): +def load_partition(partition_id, toy: bool = False): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) partition_train_test = partition_train_test.with_transform(apply_transforms) diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index b21c0d2545ca..94707b5cbc98 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (TensorFlow/Keras) -This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. 
This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 1/10 of the training data, of which 80% are used as training examples and 20% as test examples (note that by default only a small subset of this data is used when running the `run.sh` script) @@ -57,7 +57,7 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: +The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: ```shell # Once you have activated your environment diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py index f42c93784fc6..17d1d2306270 100644 --- a/examples/advanced-tensorflow/client.py +++ b/examples/advanced-tensorflow/client.py @@ -106,9 +106,9 @@ def main() -> None: x_test, y_test = x_test[:10], y_test[:10] # Start Flower client - client = CifarClient(model, x_train, y_train, x_test, y_test) + client = CifarClient(model, x_train, y_train, x_test, y_test).to_client() - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", client=client, root_certificates=Path(".cache/certificates/ca.crt").read_bytes(), diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 2f16d8a15584..02bd923129a4 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "advanced-tensorflow" version = "0.1.0" description = "Advanced Flower/TensorFlow Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java index cbf804140954..4de369a7e684 100644 --- a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java +++ b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java @@ -60,7 
@@ public void onChanged(List workInfos) { if (workInfos.size() > 0) { WorkInfo info = workInfos.get(0); int progress = info.getProgress().getInt("progress", -1); - // You can recieve any message from the Worker Thread + // You can receive any message from the Worker Thread refreshRecyclerView(); } } diff --git a/examples/android/pyproject.toml b/examples/android/pyproject.toml index 2b9cd8c978a7..0371f7208292 100644 --- a/examples/android/pyproject.toml +++ b/examples/android/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.masonry.api" name = "android_flwr_tensorflow" version = "0.1.0" description = "Android Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md new file mode 100644 index 000000000000..14de3c7d632e --- /dev/null +++ b/examples/app-pytorch/README.md @@ -0,0 +1,75 @@ +# Flower App (PyTorch) 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to start a long-running Flower server (SuperLink) and then run a Flower App (consisting of a `ClientApp` and a `ServerApp`). + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. 
+├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── server_workflow.py # <-- contains `ServerApp` with workflow +├── server_custom.py # <-- contains `ServerApp` with custom main function +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Run a simulation + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 2 +``` + +## Run a deployment + +### Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +### Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +Or, to try the workflow example, run: + +```bash +flower-server-app server_workflow:app --insecure +``` + +Or, to try the custom server function example, run: + +```bash +flower-server-app server_custom:app --insecure +``` diff --git a/examples/app-pytorch/client.py b/examples/app-pytorch/client.py new file mode 100644 index 000000000000..ebbe977ecab1 --- /dev/null +++ b/examples/app-pytorch/client.py @@ -0,0 +1,51 @@ +from flwr.client import ClientApp, NumPyClient + +from task import ( + Net, + DEVICE, + load_data, + get_weights, + set_weights, + train, + test, +) + + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + + def fit(self, parameters, config): + set_weights(net, parameters) + results = train(net, trainloader, testloader, epochs=1, device=DEVICE) + return get_weights(net), len(trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(net, parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, +) + + +# Legacy mode +if __name__ == "__main__": + from flwr.client import start_client + + start_client( + server_address="127.0.0.1:8080", + client=FlowerClient().to_client(), + ) diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py new file mode 100644 index 000000000000..feea1ee658fe --- /dev/null +++ b/examples/app-pytorch/client_low_level.py @@ -0,0 +1,35 @@ +from flwr.client import ClientApp +from flwr.common import Message, Context + + +def hello_world_mod(msg, ctx, call_next) -> Message: + print("Hello, ...[pause for dramatic effect]...") + out = call_next(msg, ctx) + print("...[pause was long enough]... 
World!") + return out + + +# Flower ClientApp +app = ClientApp( + mods=[ + hello_world_mod, + ], +) + + +@app.train() +def train(msg: Message, ctx: Context): + print("`train` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") + + +@app.evaluate() +def eval(msg: Message, ctx: Context): + print("`evaluate` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") + + +@app.query() +def query(msg: Message, ctx: Context): + print("`query` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") diff --git a/examples/app-pytorch/pyproject.toml b/examples/app-pytorch/pyproject.toml new file mode 100644 index 000000000000..e47dd2db949d --- /dev/null +++ b/examples/app-pytorch/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "app-pytorch" +version = "0.1.0" +description = "Multi-Tenant Federated Learning with Flower and PyTorch" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = "^3.8" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } +flwr-datasets = { version = "0.0.2", extras = ["vision"] } +torch = "2.2.1" +torchvision = "0.17.1" diff --git a/examples/app-pytorch/requirements.txt b/examples/app-pytorch/requirements.txt new file mode 100644 index 000000000000..016a84043cbe --- /dev/null +++ b/examples/app-pytorch/requirements.txt @@ -0,0 +1,4 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 +flwr-datasets[vision]==0.0.2 +torch==2.2.1 +torchvision==0.17.1 diff --git a/examples/mt-pytorch/start_server.py b/examples/app-pytorch/server.py similarity index 64% rename from examples/mt-pytorch/start_server.py rename to examples/app-pytorch/server.py index d96edd7d45ad..0b4ad1ddba46 100644 --- a/examples/mt-pytorch/start_server.py +++ b/examples/app-pytorch/server.py @@ -1,7 +1,10 @@ from typing import List, Tuple -import flwr as fl -from flwr.common import Metrics +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.common import Metrics, ndarrays_to_parameters + +from task import Net, get_weights # Define metric aggregation function @@ -25,17 +28,38 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + # Define strategy -strategy = fl.server.strategy.FedAvg( +strategy = FedAvg( fraction_fit=1.0, # Select all available clients fraction_evaluate=0.0, # Disable evaluation min_available_clients=2, fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, ) -# Start Flower server -fl.server.start_server( - server_address="0.0.0.0:9092", - config=fl.server.ServerConfig(num_rounds=3), + +# Define config +config = ServerConfig(num_rounds=3) + + +# Flower ServerApp +app = ServerApp( + config=config, strategy=strategy, ) + + +# Legacy mode +if __name__ == "__main__": + from flwr.server import start_server + + start_server( + server_address="0.0.0.0:8080", + config=config, + strategy=strategy, + ) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py new file mode 100644 index 000000000000..0c2851e2afee --- /dev/null +++ b/examples/app-pytorch/server_custom.py @@ -0,0 +1,148 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.common import ( + 
Context, + FitIns, + ndarrays_to_parameters, + parameters_to_ndarrays, + NDArrays, + Code, + Message, + MessageType, + Metrics, +) +from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres +from flwr.server import Driver, History +from flwr.server.strategy.aggregate import aggregate + +from task import Net, get_weights + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Run via `flower-server-app server_custom:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """Run a custom FedAvg training loop using the low-level Driver API.""" + print("RUNNING!!!!!") + + num_client_nodes_per_round = 2 + sleep_time = 1 + num_rounds = 3 + parameters = ndarrays_to_parameters(get_weights(net=Net())) + + history = History() + for server_round in range(num_rounds): + print(f"Commencing server round {server_round + 1}") + + # List of sampled node IDs in this round + sampled_nodes: List[int] = [] + + # The Driver API might not immediately return enough client node IDs, so we + # loop and wait until enough client nodes are available.
+ while True: + all_node_ids = driver.get_node_ids() + + print(f"Got {len(all_node_ids)} client nodes: {all_node_ids}") + if len(all_node_ids) >= num_client_nodes_per_round: + # Sample client nodes + sampled_nodes = random.sample(all_node_ids, num_client_nodes_per_round) + break + time.sleep(3) + + # Log sampled node IDs + print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") + + # Schedule a task for all sampled nodes + fit_ins: FitIns = FitIns(parameters=parameters, config={}) + recordset = fitins_to_recordset(fitins=fit_ins, keep_input=True) + + messages = [] + for node_id in sampled_nodes: + message = driver.create_message( + content=recordset, + message_type=MessageType.TRAIN, + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + # Convert the replies into FitRes objects + all_fitres = [ + recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + ] + print(f"Received {len(all_fitres)} results") + + weights_results: List[Tuple[NDArrays, int]] = [] + metrics_results: List[Tuple[int, Dict]] = [] + for fitres in all_fitres: + print(f"num_examples: {fitres.num_examples}, status: {fitres.status.code}") + + # Aggregate only if the status is OK + if fitres.status.code != Code.OK: + continue + weights_results.append( + (parameters_to_ndarrays(fitres.parameters), fitres.num_examples) + ) + metrics_results.append((fitres.num_examples, fitres.metrics)) + + # Aggregate parameters (FedAvg) + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated + + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + + # Slow down the start of the next round + time.sleep(sleep_time) + + print(f"app_fit: losses_distributed {history.losses_distributed}") + print(f"app_fit: metrics_distributed_fit {history.metrics_distributed_fit}") + print(f"app_fit: metrics_distributed {history.metrics_distributed}") + print(f"app_fit: losses_centralized {history.losses_centralized}") + print(f"app_fit: metrics_centralized {history.metrics_centralized}") diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py new file mode 100644 index 000000000000..560babac1b95 --- /dev/null +++ b/examples/app-pytorch/server_low_level.py @@ -0,0 +1,52 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.server import Driver + + +# Run via `flower-server-app server_low_level:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """This is a stub example that simply sends and receives messages.""" + print("Starting test run") + for server_round in range(3): + print(f"Commencing server round {server_round + 1}") + + # Get node IDs + node_ids
= driver.get_node_ids() + + # Create messages + recordset = RecordSet() + messages = [] + for node_id in node_ids: + message = driver.create_message( + content=recordset, + message_type=MessageType.TRAIN, + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + # Send messages + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + print(f"Received {len(all_replies)} results") diff --git a/examples/app-pytorch/server_workflow.py b/examples/app-pytorch/server_workflow.py new file mode 100644 index 000000000000..6923010ecf7b --- /dev/null +++ b/examples/app-pytorch/server_workflow.py @@ -0,0 +1,63 @@ +from typing import List, Tuple + +from task import Net, get_weights + +import flwr as fl +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server import Driver, LegacyContext + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable evaluation + min_available_clients=2, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, +) + + +# Run via `flower-server-app server_workflow:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = fl.server.workflow.DefaultWorkflow() + + # Execute + workflow(driver, context) diff --git a/examples/app-pytorch/task.py b/examples/app-pytorch/task.py new file mode 100644 index 000000000000..240f290df320 --- /dev/null +++ b/examples/app-pytorch/task.py @@ -0,0 +1,94 @@ +from collections import OrderedDict +from logging import INFO + +import torch +import torch.nn as nn +import torch.nn.functional as F +from flwr.common.logger import log +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class 
Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +def train(net, trainloader, valloader, epochs, device): + """Train the model on the training set.""" + log(INFO, "Starting training...") + net.to(device)  # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net.train() + for _ in range(epochs): + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + loss.backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + "train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + """Validate the model on the test set.""" + net.to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in testloader: + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md new file mode 100644 index 000000000000..d1ea7bdc893f --- /dev/null +++ b/examples/app-secure-aggregation/README.md @@ -0,0 +1,93 @@ +# Secure aggregation with Flower (the SecAgg+ protocol) 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to use Secure Aggregation in Flower, with `ClientApp` using `secaggplus_mod` and `ServerApp` using `SecAggPlusWorkflow`. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +.
+├── client.py # Client application using `secaggplus_mod` +├── server.py # Server application using `SecAggPlusWorkflow` +├── workflow_with_log.py # Augmented `SecAggPlusWorkflow` +├── run.sh # Quick start script +├── pyproject.toml # Project dependencies (poetry) +└── requirements.txt # Project dependencies (pip) +``` + +## Installing dependencies + +Project dependencies (such as `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +### pip + +Run the command below in your terminal to install the dependencies listed in `requirements.txt`. + +```shell +pip install -r requirements.txt +``` + +If you don't see any errors you're good to go! + +## Run the example with one command (recommended) + +```bash +./run.sh +``` + +## Run the example with the simulation engine + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 5 +``` + +## Alternatively, run the example (in 7 terminal windows) + +Start the Flower SuperLink in one terminal window: + +```bash +flower-superlink --insecure +``` + +Start 5 Flower `ClientApp` instances in 5 separate terminal windows: + +```bash +flower-client-app client:app --insecure +``` + +Start the Flower `ServerApp`: + +```bash +flower-server-app server:app --insecure --verbose +``` + +## Amend the example for practical usage + +For real-world applications, modify the `workflow` in `server.py` as follows: + +```python +workflow = fl.server.workflow.DefaultWorkflow( + fit_workflow=SecAggPlusWorkflow( + num_shares=<number of shares>, + reconstruction_threshold=<reconstruction threshold>, + ) +) +``` diff --git a/examples/app-secure-aggregation/client.py b/examples/app-secure-aggregation/client.py new file mode 100644 index 000000000000..b2fd02ec00d4 --- /dev/null +++ b/examples/app-secure-aggregation/client.py @@ -0,0 +1,34 @@ +import time + +from flwr.client import ClientApp, NumPyClient +from flwr.client.mod import secaggplus_mod +import numpy as np + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + def fit(self, parameters, config): + # Instead of training and returning model parameters, + # the client directly returns [1.0, 1.0, 1.0] for demonstration purposes.
+ ret_vec = [np.ones(3)] + # Force a significant delay for testing purposes + if "drop" in config and config["drop"]: + print("Client dropped for testing purposes.") + time.sleep(8) + else: + print(f"Client uploading {ret_vec[0]}...") + return ret_vec, 1, {} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + ], +) diff --git a/examples/app-secure-aggregation/pyproject.toml b/examples/app-secure-aggregation/pyproject.toml new file mode 100644 index 000000000000..84b6502064c8 --- /dev/null +++ b/examples/app-secure-aggregation/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "app-secure-aggregation" +version = "0.1.0" +description = "Flower Secure Aggregation example." +authors = ["The Flower Authors <hello@flower.ai>"] + +[tool.poetry.dependencies] +python = "^3.8" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } diff --git a/examples/app-secure-aggregation/requirements.txt b/examples/app-secure-aggregation/requirements.txt new file mode 100644 index 000000000000..5bac63a0d44c --- /dev/null +++ b/examples/app-secure-aggregation/requirements.txt @@ -0,0 +1 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 diff --git a/examples/app-secure-aggregation/run.sh b/examples/app-secure-aggregation/run.sh new file mode 100755 index 000000000000..fa8dc47f26ef --- /dev/null +++ b/examples/app-secure-aggregation/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Kill any currently running flower-client-app processes +pkill -f 'flower-client-app' + +# Kill any currently running flower-superlink processes +pkill -f 'flower-superlink' + +# Start the flower server +echo "Starting flower server in background..." +flower-superlink --insecure > /dev/null 2>&1 & +sleep 2 + +# Number of client processes to start +N=5 # Replace with your desired value + +echo "Starting $N ClientApps in background..." + +# Start N client processes +for i in $(seq 1 $N) +do + flower-client-app --insecure client:app > /dev/null 2>&1 & + sleep 0.1 +done + +echo "Starting ServerApp..." +flower-server-app --insecure server:app --verbose + +echo "Clearing background processes..."
+ +# Kill any currently running flower-client-app processes +pkill -f 'flower-client-app' + +# Kill any currently running flower-superlink processes +pkill -f 'flower-superlink' diff --git a/examples/app-secure-aggregation/server.py b/examples/app-secure-aggregation/server.py new file mode 100644 index 000000000000..e9737a5a3c7f --- /dev/null +++ b/examples/app-secure-aggregation/server.py @@ -0,0 +1,45 @@ +from flwr.common import Context +from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow + +from workflow_with_log import SecAggPlusWorkflowWithLogs + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0,  # Select all available clients + fraction_evaluate=0.0,  # Disable evaluation + min_available_clients=5, +) + + +# Flower ServerApp +app = ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = DefaultWorkflow( + fit_workflow=SecAggPlusWorkflowWithLogs( + num_shares=3, + reconstruction_threshold=2, + timeout=5, + ) + # # For real-world applications, use the following code instead + # fit_workflow=SecAggPlusWorkflow( + #     num_shares=<number of shares>, + #     reconstruction_threshold=<reconstruction threshold>, + # ) + ) + + # Execute + workflow(driver, context) diff --git a/examples/app-secure-aggregation/workflow_with_log.py b/examples/app-secure-aggregation/workflow_with_log.py new file mode 100644 index 000000000000..a03ff8c13b6c --- /dev/null +++ b/examples/app-secure-aggregation/workflow_with_log.py @@ -0,0 +1,92 @@ +from flwr.common import Context, log, parameters_to_ndarrays +from logging import INFO +from flwr.server import Driver, LegacyContext +from flwr.server.workflow.secure_aggregation.secaggplus_workflow import ( + SecAggPlusWorkflow, + WorkflowState, +) +import numpy as np +from flwr.common.secure_aggregation.quantization import quantize +from flwr.server.workflow.constant import MAIN_PARAMS_RECORD +import flwr.common.recordset_compat as compat + + +class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow): + """The SecAggPlusWorkflow augmented for this example. + + This class includes additional logging and modifies one of the FitIns to instruct + the target client to simulate a dropout.
+ """ + + node_ids = [] + + def __call__(self, driver: Driver, context: Context) -> None: + _quantized = quantize( + [np.ones(3) for _ in range(5)], self.clipping_range, self.quantization_range + ) + log(INFO, "") + log( + INFO, + "################################ Introduction ################################", + ) + log( + INFO, + "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of", + ) + log(INFO, "model updates for demonstration purposes.") + log( + INFO, + "Client 0 is configured to drop out before uploading the masked vector.", + ) + log(INFO, "After quantization, the raw vectors will look like:") + for i in range(1, 5): + log(INFO, "\t%s from Client %s", _quantized[i], i) + log( + INFO, + "Numbers are rounded to integers stochastically during the quantization", + ) + log(INFO, ", and thus entries may not be identical.") + log( + INFO, + "The above raw vectors are hidden from the driver through adding masks.", + ) + log(INFO, "") + log( + INFO, + "########################## Secure Aggregation Start ##########################", + ) + + super().__call__(driver, context) + + paramsrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] + parameters = compat.parametersrecord_to_parameters(paramsrecord, True) + ndarrays = parameters_to_ndarrays(parameters) + log( + INFO, + "Weighted average of vectors (dequantized): %s", + ndarrays[0], + ) + log( + INFO, + "########################### Secure Aggregation End ###########################", + ) + log(INFO, "") + + def setup_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + ret = super().setup_stage(driver, context, state) + self.node_ids = list(state.active_node_ids) + state.nid_to_fitins[self.node_ids[0]].configs_records["fitins.config"][ + "drop" + ] = True + return ret + + def collect_masked_vectors_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + ret = super().collect_masked_vectors_stage(driver, context, state) + for node_id in state.sampled_node_ids - state.active_node_ids: + log(INFO, "Client %s dropped out.", self.node_ids.index(node_id)) + log(INFO, "Obtained sum of masked vectors: %s", state.aggregate_ndarrays[1]) + return ret diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md index debcd7919839..317fb6336106 100644 --- a/examples/custom-metrics/README.md +++ b/examples/custom-metrics/README.md @@ -9,7 +9,7 @@ The main takeaways of this implementation are: - the use of the `output_dict` on the client side - inside `evaluate` method on `client.py` - the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server.py` -This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.dev/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve the CIFAR-10. +This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.ai/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve the CIFAR-10. Using the CIFAR-10 dataset for classification, this is a multi-class classification problem, thus some changes on how to calculate the metrics using `average='micro'` and `np.argmax` is required. For binary classification, this is not required. 
Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on the client side. @@ -91,16 +91,16 @@ chmod +x run.sh ./run.sh ``` -You will see that Keras is starting a federated training. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development). +You will see that Keras is starting a federated training. Have a look at the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development). Running `run.sh` will result in the following output (after 3 rounds): ```shell INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { - 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], - 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], - 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], - 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], + 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], 'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] } ``` diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py index b2206118ed44..d0230e455477 100644 --- a/examples/custom-metrics/client.py +++ b/examples/custom-metrics/client.py @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml index 8a2da6562018..51c29e213d81 100644 --- a/examples/custom-metrics/pyproject.toml +++ b/examples/custom-metrics/pyproject.toml @@ -7,8 +7,8 @@ name = "custom-metrics" version = "0.1.0" description = "Federated Learning with Flower and Custom Metrics" authors = [ - "The Flower Authors ", - "Gustavo Bertoli " + "The Flower Authors ", + "Gustavo Bertoli ", ] [tool.poetry.dependencies] @@ -16,4 +16,4 @@ python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { version = "*", extras = ["vision"] } scikit-learn = "^1.2.2" -tensorflow = "==2.12.0" \ No newline at end of file +tensorflow = "==2.12.0" diff --git a/examples/custom-mods/.gitignore b/examples/custom-mods/.gitignore new file mode 100644 index 000000000000..260d28a67c6f --- /dev/null +++ b/examples/custom-mods/.gitignore @@ -0,0 +1,2 @@ +wandb/ +.runs_history/ diff --git a/examples/custom-mods/README.md b/examples/custom-mods/README.md new file mode 100644 index 000000000000..b0ad668c2dec --- /dev/null +++ b/examples/custom-mods/README.md @@ -0,0 +1,339 @@ +# Using custom mods 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples
([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to write custom Flower Mods and use them in a simple example. + +## Writing custom Flower Mods + +### Flower Mods basics + +As described [here](https://flower.ai/docs/framework/how-to-use-built-in-mods.html#what-are-mods), Flower Mods in their simplest form can be described as: + +```python +def basic_mod(msg: Message, context: Context, app: ClientApp) -> Message: +    # Do something with incoming Message (or Context) +    # before passing to the inner ``ClientApp`` +    reply = app(msg, context) +    # Do something with outgoing Message (or Context) +    # before returning +    return reply +``` + +and used when defining the `ClientApp`: + +```python +app = fl.client.ClientApp( +    client_fn=client_fn, +    mods=[basic_mod], +) +``` + +Note that in this specific case, this mod won't modify anything, and FL will run as usual. + +### WandB Flower Mod + +If we want to write a mod to monitor our client-side training using [Weights & Biases](https://github.com/wandb/wandb), we can follow the steps below. + +First, we need to initialize our W&B project with the correct parameters: + +```python +wandb.init( +    project=..., +    group=..., +    name=..., +    id=..., +    resume="allow", +    reinit=True, +) +``` + +In our case, the group should be the `run_id`, specific to a `ServerApp` run, and the `name` should be the `node_id`. This will make it easy to navigate our W&B project, as for each run we will be able to see the computed results as a whole or for each individual client. + +The `id` needs to be unique, so it will be a combination of `run_id` and `node_id`. + +In the end we have: + +```python +def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: +    run_id = msg.metadata.run_id +    group_name = f"Run ID: {run_id}" + +    node_id = str(msg.metadata.dst_node_id) +    run_name = f"Node ID: {node_id}" + +    wandb.init( +        project="Mod Name", +        group=group_name, +        name=run_name, +        id=f"{run_id}_{node_id}", +        resume="allow", +        reinit=True, +    ) +``` + +Now, before the message is processed by the `ClientApp`, we will store the starting time and the round number, in order to compute the time it took the client to perform its fit step. + +```python +server_round = int(msg.metadata.group_id) +start_time = time.time() +``` + +And then, we can send the message to the client: + +```python +reply = app(msg, context) +``` + +And now, with the message we got back, we can gather our metrics: + +```python +if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + +    time_diff = time.time() - start_time + +    metrics = reply.content.configs_records + +    results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) +    results_to_log["fit_time"] = time_diff +``` + +Note that we store our metrics in the `results_to_log` variable and that we only initialize this variable when our client is sending back fit results (with content in it).
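For intuition, assuming the client's fit step returned the training metrics used elsewhere in this example, `results_to_log` would end up looking roughly like this (all values are made up for illustration):

```python
# Hypothetical contents of `results_to_log` after one fit step; the metric
# names mirror those returned by the example's training code.
results_to_log = {
    "train_loss": 2.05,
    "train_accuracy": 0.31,
    "val_loss": 2.11,
    "val_accuracy": 0.29,
    "fit_time": 12.4,  # seconds, measured by the mod itself
}
```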
+ +Finally, we can send our results to W&B using: + +```python +wandb.log(results_to_log, step=int(server_round), commit=True) +``` + +The complete mod becomes: + +```python +def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: +    server_round = int(msg.metadata.group_id) + +    if msg.metadata.message_type == MessageType.TRAIN and server_round == 1: +        run_id = msg.metadata.run_id +        group_name = f"Run ID: {run_id}" + +        node_id = str(msg.metadata.dst_node_id) +        run_name = f"Node ID: {node_id}" + +        wandb.init( +            project="Mod Name", +            group=group_name, +            name=run_name, +            id=f"{run_id}_{node_id}", +            resume="allow", +            reinit=True, +        ) + +    start_time = time.time() + +    reply = app(msg, context) + +    if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + +        time_diff = time.time() - start_time + +        metrics = reply.content.configs_records + +        results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + +        results_to_log["fit_time"] = time_diff + +        wandb.log(results_to_log, step=int(server_round), commit=True) + +    return reply +``` + +And it can be used like: + +```python +app = fl.client.ClientApp( +    client_fn=client_fn, +    mods=[wandb_mod], +) +``` + +If we want to pass an argument to our mod, we can use a wrapper function: + +```python +def get_wandb_mod(name: str) -> Mod: +    def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: +        server_round = int(msg.metadata.group_id) + +        run_id = msg.metadata.run_id +        group_name = f"Run ID: {run_id}" + +        node_id = str(msg.metadata.dst_node_id) +        run_name = f"Node ID: {node_id}" + +        wandb.init( +            project=name, +            group=group_name, +            name=run_name, +            id=f"{run_id}_{node_id}", +            resume="allow", +            reinit=True, +        ) + +        start_time = time.time() + +        reply = app(msg, context) + +        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + +            time_diff = time.time() - start_time + +            metrics = reply.content.configs_records + +            results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + +            results_to_log["fit_time"] = time_diff + +            wandb.log(results_to_log, step=int(server_round), commit=True) + +        return reply + +    return wandb_mod +``` + +And use it like: + +```python +app = fl.client.ClientApp( +    client_fn=client_fn, +    mods=[ +        get_wandb_mod("Custom mods example"), +    ], +) +``` + +### TensorBoard Flower Mod + +The [TensorBoard](https://www.tensorflow.org/tensorboard) Mod will only differ in the initialization and how the data is sent to TensorBoard: + +```python +def get_tensorboard_mod(logdir) -> Mod: +    os.makedirs(logdir, exist_ok=True) + +    def tensorboard_mod( +        msg: Message, context: Context, app: ClientAppCallable +    ) -> Message: +        logdir_run = os.path.join(logdir, str(msg.metadata.run_id)) + +        node_id = str(msg.metadata.dst_node_id) + +        server_round = int(msg.metadata.group_id) + +        start_time = time.time() + +        reply = app(msg, context) + +        time_diff = time.time() - start_time + +        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): +            writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id)) + +            metrics = dict( +                reply.content.configs_records.get("fitres.metrics", ConfigsRecord()) +            ) + +            with writer.as_default(step=server_round): +                tf.summary.scalar(f"fit_time", time_diff, step=server_round) +                for metric in metrics: +                    tf.summary.scalar( +                        f"{metric}", +                        metrics[metric], +                        step=server_round, +                    ) +                writer.flush() + +        return reply + +    return tensorboard_mod +``` + +For the initialization, TensorBoard uses a custom
directory path, which can, in this case, be passed as an argument to the wrapper function. + +It can be used in the following way: + +```python +app = fl.client.ClientApp( +    client_fn=client_fn, +    mods=[get_tensorboard_mod(".runs_history/")], +) +``` + +## Running the example + +### Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. +├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies +``` + +### Install dependencies + +```bash +pip install -r requirements.txt +``` + +For [W&B](https://wandb.ai) you will also need a valid account. + +### Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client using: + +```bash +flower-client-app client:wandb_app --insecure +``` + +for W&B monitoring, or: + +```bash +flower-client-app client:tb_app --insecure +``` + +for TensorBoard. + +In yet another new terminal window, start the second long-running Flower client (with the mod of your choice): + +```bash +flower-client-app client:{wandb,tb}_app --insecure +``` + +### Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +### Check the results + +For W&B, you will need to log in to the [website](https://wandb.ai). + +For TensorBoard, you will need to run the following command in your terminal: + +```sh +tensorboard --logdir <LOG_DIR> +``` + +Where `<LOG_DIR>` needs to be replaced by the directory passed as an argument to the wrapper function (`.runs_history/` by default).
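Nothing prevents stacking both mods on a single `ClientApp`. A minimal sketch, assuming `get_wandb_mod` and `get_tensorboard_mod` are defined as above and `client_fn` is the one from the example's `client.py`:

```python
import flwr as fl

# Mods wrap the ClientApp like middleware: the first mod in the list sees the
# incoming message first and the outgoing reply last.
app = fl.client.ClientApp(
    client_fn=client_fn,
    mods=[
        get_wandb_mod("Custom mods example"),
        get_tensorboard_mod(".runs_history/"),
    ],
)
```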
diff --git a/examples/custom-mods/client.py b/examples/custom-mods/client.py new file mode 100644 index 000000000000..2b87a24da19d --- /dev/null +++ b/examples/custom-mods/client.py @@ -0,0 +1,160 @@ +import logging +import os +import time + +import flwr as fl +import tensorflow as tf +import wandb +from flwr.common import ConfigsRecord +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common.context import Context +from flwr.common.message import Message +from flwr.common.constant import MessageType + +from task import ( + Net, + DEVICE, + load_data, + get_parameters, + set_parameters, + train, + test, +) + + +class WBLoggingFilter(logging.Filter): + def filter(self, record): + return ( + "login" in record.getMessage() + or "View project at" in record.getMessage() + or "View run at" in record.getMessage() + ) + + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return get_parameters(net) + + def fit(self, parameters, config): + set_parameters(net, parameters) + results = train(net, trainloader, testloader, epochs=1, device=DEVICE) + return get_parameters(net), len(trainloader.dataset), results + + def evaluate(self, parameters, config): + set_parameters(net, parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + return FlowerClient().to_client() + + +def get_wandb_mod(name: str) -> Mod: + def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: + """Flower Mod that logs the metrics dictionary returned by the client's fit + function to Weights & Biases.""" + server_round = int(msg.metadata.group_id) + + if server_round == 1 and msg.metadata.message_type == MessageType.TRAIN: + run_id = msg.metadata.run_id + group_name = f"Run ID: {run_id}" + + node_id = str(msg.metadata.dst_node_id) + run_name = f"Node ID: {node_id}" + + wandb.init( + project=name, + group=group_name, + name=run_name, + id=f"{run_id}_{node_id}", + resume="allow", + reinit=True, + ) + + start_time = time.time() + + reply = app(msg, context) + + time_diff = time.time() - start_time + + # if the `ClientApp` just processed a "fit" message, let's log some metrics to W&B + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + + metrics = reply.content.configs_records + + results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + + results_to_log["fit_time"] = time_diff + + wandb.log(results_to_log, step=int(server_round), commit=True) + + return reply + + return wandb_mod + + +def get_tensorboard_mod(logdir) -> Mod: + os.makedirs(logdir, exist_ok=True) + + def tensorboard_mod( + msg: Message, context: Context, app: ClientAppCallable + ) -> Message: + """Flower Mod that logs the metrics dictionary returned by the client's fit + function to TensorBoard.""" + logdir_run = os.path.join(logdir, str(msg.metadata.run_id)) + + node_id = str(msg.metadata.dst_node_id) + + server_round = int(msg.metadata.group_id) + + start_time = time.time() + + reply = app(msg, context) + + time_diff = time.time() - start_time + + # if the `ClientApp` just processed a "fit" message, let's log some metrics to TensorBoard + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id)) + + metrics = dict( + 
reply.content.configs_records.get("fitres.metrics", ConfigsRecord()) + ) + + with writer.as_default(step=server_round): + tf.summary.scalar(f"fit_time", time_diff, step=server_round) + for metric in metrics: + tf.summary.scalar( + f"{metric}", + metrics[metric], + step=server_round, + ) + writer.flush() + + return reply + + return tensorboard_mod + + +# Run via `flower-client-app client:wandb_app` +wandb_app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + get_wandb_mod("Custom mods example"), + ], +) + +# Run via `flower-client-app client:tb_app` +tb_app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + get_tensorboard_mod(".runs_history/"), + ], +) diff --git a/examples/mt-pytorch/pyproject.toml b/examples/custom-mods/pyproject.toml similarity index 62% rename from examples/mt-pytorch/pyproject.toml rename to examples/custom-mods/pyproject.toml index 4978035495ea..e690e05bab8f 100644 --- a/examples/mt-pytorch/pyproject.toml +++ b/examples/custom-mods/pyproject.toml @@ -3,14 +3,16 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "mt-pytorch" +name = "app-pytorch" version = "0.1.0" description = "Multi-Tenant Federated Learning with Flower and PyTorch" -authors = ["The Flower Authors <hello@flower.dev>"] +authors = ["The Flower Authors <hello@flower.ai>"] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = {version = ">=1.0,<2.0", extras = ["rest", "simulation"]} +flwr = { path = "../../", develop = true, extras = ["simulation"] } +tensorboard = "2.16.2" torch = "1.13.1" torchvision = "0.14.1" tqdm = "4.65.0" +wandb = "0.16.3" diff --git a/examples/mt-pytorch/requirements.txt b/examples/custom-mods/requirements.txt similarity index 72% rename from examples/mt-pytorch/requirements.txt rename to examples/custom-mods/requirements.txt index ae0a65386f2b..75b2c1135f11 100644 --- a/examples/mt-pytorch/requirements.txt +++ b/examples/custom-mods/requirements.txt @@ -1,4 +1,6 @@ flwr-nightly[rest,simulation]>=1.0, <2.0 +tensorboard==2.16.2 torch==1.13.1 torchvision==0.14.1 tqdm==4.65.0 +wandb==0.16.3 diff --git a/examples/mt-pytorch/start_driver.py b/examples/custom-mods/server.py similarity index 68% rename from examples/mt-pytorch/start_driver.py rename to examples/custom-mods/server.py index 3241a548950a..c2d8a4fe5ee7 100644 --- a/examples/mt-pytorch/start_driver.py +++ b/examples/custom-mods/server.py @@ -9,12 +9,16 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_losses = [ + num_examples * float(m["train_loss"]) for num_examples, m in metrics + ] train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics + num_examples * float(m["train_accuracy"]) for num_examples, m in metrics + ] + val_losses = [num_examples * float(m["val_loss"]) for num_examples, m in metrics] + val_accuracies = [ + num_examples * float(m["val_accuracy"]) for num_examples, m in metrics ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] # Aggregate and return custom metric (weighted average) return { @@ -33,9 +37,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: fit_metrics_aggregation_fn=weighted_average, ) -# Start Flower server -fl.driver.start_driver( -
server_address="0.0.0.0:9091", + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp( config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, ) diff --git a/examples/mt-pytorch/task.py b/examples/custom-mods/task.py similarity index 100% rename from examples/mt-pytorch/task.py rename to examples/custom-mods/task.py diff --git a/examples/doc/source/_templates/base.html b/examples/doc/source/_templates/base.html index e4fe80720b74..08030fb08c15 100644 --- a/examples/doc/source/_templates/base.html +++ b/examples/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/examples/doc/source/conf.py b/examples/doc/source/conf.py index 3d629c39c7ea..bf177aa5ae24 100644 --- a/examples/doc/source/conf.py +++ b/examples/doc/source/conf.py @@ -30,7 +30,7 @@ author = "The Flower Authors" # The full version, including alpha/beta/rc tags -release = "1.7.0" +release = "1.8.0" # -- General configuration --------------------------------------------------- @@ -76,7 +76,7 @@ html_title = f"Flower Examples {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/examples/" +html_baseurl = "https://flower.ai/docs/examples/" html_theme_options = { # diff --git a/examples/dp-sgd-mnist/pyproject.toml b/examples/dp-sgd-mnist/pyproject.toml index 3158ab3d52f6..161952fd2aa4 100644 --- a/examples/dp-sgd-mnist/pyproject.toml +++ b/examples/dp-sgd-mnist/pyproject.toml @@ -7,14 +7,14 @@ name = "dp-sgd-mnist" version = "0.1.0" description = "Federated training with Tensorflow Privacy" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = ">=3.8,<3.11" # flwr = { path = "../../", develop = true } # Development flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow-privacy = "0.8.10" diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile index ea63839bc9d6..a85c05c4bb7a 100644 --- a/examples/embedded-devices/Dockerfile +++ b/examples/embedded-devices/Dockerfile @@ -8,6 +8,7 @@ RUN pip3 install --upgrade pip # Install flower RUN pip3 install flwr>=1.0 +RUN pip3 install flwr-datsets>=0.2 RUN pip3 install tqdm==4.65.0 WORKDIR /client diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md index 4c79eafbbf84..f1c5931b823a 100644 --- a/examples/embedded-devices/README.md +++ b/examples/embedded-devices/README.md @@ -192,7 +192,8 @@ On the machine of your choice, launch the server: # Launch your server. 
# Will wait for at least 2 clients to be connected, then will train for 3 FL rounds # The command below will sample all clients connected (since sample_fraction=1.0) -python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 # append `--mnist` if you want to use that dataset/model setting +# The server is dataset agnostic (use the same command for MNIST and CIFAR10) +python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 ``` > If you are on macOS with Apple Silicon (i.e. M1, M2 chips), you might encounter a `grpcio`-related issue when launching your server. If you are in a conda environment you can solve this easily by doing: `pip uninstall grpcio` and then `conda install grpcio`. diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index 5d236c9e9389..3f1e6c7d51b7 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -6,18 +6,19 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader, random_split -from torchvision.datasets import CIFAR10, MNIST +from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize, ToTensor from torchvision.models import mobilenet_v3_small from tqdm import tqdm +from flwr_datasets import FederatedDataset + parser = argparse.ArgumentParser(description="Flower Embedded devices") parser.add_argument( "--server_address", type=str, default="0.0.0.0:8080", - help=f"gRPC server address (deafault '0.0.0.0:8080')", + help="gRPC server address (default '0.0.0.0:8080')", ) parser.add_argument( "--cid", @@ -28,25 +29,13 @@ parser.add_argument( "--mnist", action="store_true", - help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use MNIST", + help="If you use Raspberry Pi Zero clients (which just have 512MB of RAM) use " + "MNIST", ) - warnings.filterwarnings("ignore", category=UserWarning) NUM_CLIENTS = 50 -# a config for mobilenetv2 that works for -# small input sizes (i.e.
32x32 as in CIFAR) -mb2_cfg = [ - (1, 16, 1, 1), - (6, 24, 2, 1), - (6, 32, 3, 2), - (6, 64, 4, 2), - (6, 96, 3, 1), - (6, 160, 3, 2), - (6, 320, 1, 1), -] - class Net(nn.Module): """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz').""" @@ -73,7 +62,9 @@ def train(net, trainloader, optimizer, epochs, device): """Train the model on the training set.""" criterion = torch.nn.CrossEntropyLoss() for _ in range(epochs): - for images, labels in tqdm(trainloader): + for batch in tqdm(trainloader): + batch = list(batch.values()) + images, labels = batch[0], batch[1] optimizer.zero_grad() criterion(net(images.to(device)), labels.to(device)).backward() optimizer.step() @@ -84,7 +75,9 @@ def test(net, testloader, device): criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): - for images, labels in tqdm(testloader): + for batch in tqdm(testloader): + batch = list(batch.values()) + images, labels = batch[0], batch[1] outputs = net(images.to(device)) labels = labels.to(device) loss += criterion(outputs, labels).item() @@ -95,44 +88,33 @@ def test(net, testloader, device): def prepare_dataset(use_mnist: bool): """Get MNIST/CIFAR-10 and return client partitions and global testset.""" - dataset = MNIST if use_mnist else CIFAR10 if use_mnist: + fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + img_key = "image" norm = Normalize((0.1307,), (0.3081,)) else: + fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS}) + img_key = "img" norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - - trf = Compose([ToTensor(), norm]) - trainset = dataset("./data", train=True, download=True, transform=trf) - testset = dataset("./data", train=False, download=True, transform=trf) - - print("Partitioning dataset (IID)...") - - # Split trainset into `num_partitions` trainsets - num_images = len(trainset) // NUM_CLIENTS - partition_len = [num_images] * NUM_CLIENTS - - trainsets = random_split( - trainset, partition_len, torch.Generator().manual_seed(2023) - ) - - val_ratio = 0.1 - - # Create dataloaders with train+val support - train_partitions = [] - val_partitions = [] - for trainset_ in trainsets: - num_total = len(trainset_) - num_val = int(val_ratio * num_total) - num_train = num_total - num_val - - for_train, for_val = random_split( - trainset_, [num_train, num_val], torch.Generator().manual_seed(2023) - ) - - train_partitions.append(for_train) - val_partitions.append(for_val) - - return train_partitions, val_partitions, testset + pytorch_transforms = Compose([ToTensor(), norm]) + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch[img_key] = [pytorch_transforms(img) for img in batch[img_key]] + return batch + + trainsets = [] + validsets = [] + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") + # Divide data on each node: 90% train, 10% test + partition = partition.train_test_split(test_size=0.1) + partition = partition.with_transform(apply_transforms) + trainsets.append(partition["train"]) + validsets.append(partition["test"]) + testset = fds.load_full("test") + testset = testset.with_transform(apply_transforms) + return trainsets, validsets, testset # Flower client, adapted from Pytorch quickstart/simulation example @@ -148,8 +130,6 @@ def __init__(self, trainset, valset, use_mnist): self.model = Net() else: self.model = mobilenet_v3_small(num_classes=10) - # let's not reduce spatial resolution too early - 
self.model.features[0][0].stride = (1, 1) # Determine device self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model.to(self.device) # send model to device @@ -200,15 +180,15 @@ def main(): assert args.cid < NUM_CLIENTS use_mnist = args.mnist - # Download CIFAR-10 dataset and partition it + # Download dataset and partition it trainsets, valsets, _ = prepare_dataset(use_mnist) # Start Flower client setting its associated data partition - fl.client.start_numpy_client( + fl.client.start_client( server_address=args.server_address, client=FlowerClient( trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist - ), + ).to_client(), ) diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py index 3457af1c7a66..d59b31ab1569 100644 --- a/examples/embedded-devices/client_tf.py +++ b/examples/embedded-devices/client_tf.py @@ -6,6 +6,8 @@ import tensorflow as tf from tensorflow import keras as keras +from flwr_datasets import FederatedDataset + parser = argparse.ArgumentParser(description="Flower Embedded devices") parser.add_argument( "--server_address", @@ -32,30 +34,28 @@ def prepare_dataset(use_mnist: bool): """Download and partitions the CIFAR-10/MNIST dataset.""" if use_mnist: - (x_train, y_train), testset = tf.keras.datasets.mnist.load_data() + fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + img_key = "image" else: - (x_train, y_train), testset = tf.keras.datasets.cifar10.load_data() + fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS}) + img_key = "img" partitions = [] - # We keep all partitions equal-sized in this example - partition_size = math.floor(len(x_train) / NUM_CLIENTS) - for cid in range(NUM_CLIENTS): - # Split dataset into non-overlapping NUM_CLIENT partitions - idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size - - x_train_cid, y_train_cid = ( - x_train[idx_from:idx_to] / 255.0, - y_train[idx_from:idx_to], + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") + partition.set_format("numpy") + # Divide data on each node: 90% train, 10% test + partition = partition.train_test_split(test_size=0.1) + x_train, y_train = ( + partition["train"][img_key] / 255.0, + partition["train"]["label"], ) - - # now partition into train/validation - # Use 10% of the client's training data for validation - split_idx = math.floor(len(x_train_cid) * 0.9) - - client_train = (x_train_cid[:split_idx], y_train_cid[:split_idx]) - client_val = (x_train_cid[split_idx:], y_train_cid[split_idx:]) - partitions.append((client_train, client_val)) - - return partitions, testset + x_test, y_test = partition["test"][img_key] / 255.0, partition["test"]["label"] + partitions.append(((x_train, y_train), (x_test, y_test))) + data_centralized = fds.load_full("test") + data_centralized.set_format("numpy") + x_centralized = data_centralized[img_key] / 255.0 + y_centralized = data_centralized["label"] + return partitions, (x_centralized, y_centralized) class FlowerClient(fl.client.NumPyClient): @@ -68,7 +68,7 @@ def __init__(self, trainset, valset, use_mnist: bool): # Instantiate model if use_mnist: # small model for MNIST - self.model = model = keras.Sequential( + self.model = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), keras.layers.Conv2D(32, kernel_size=(5, 5), activation="relu"), @@ -118,14 +118,16 @@ def main(): assert args.cid < NUM_CLIENTS use_mnist = args.mnist - # Download CIFAR-10 dataset and 
partition it
+    # Download dataset and partition it
     partitions, _ = prepare_dataset(use_mnist)
     trainset, valset = partitions[args.cid]
 
     # Start Flower client setting its associated data partition
-    fl.client.start_numpy_client(
+    fl.client.start_client(
         server_address=args.server_address,
-        client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist),
+        client=FlowerClient(
+            trainset=trainset, valset=valset, use_mnist=use_mnist
+        ).to_client(),
     )
diff --git a/examples/embedded-devices/requirements_pytorch.txt b/examples/embedded-devices/requirements_pytorch.txt
index 797ca6db6244..f859c4efef17 100644
--- a/examples/embedded-devices/requirements_pytorch.txt
+++ b/examples/embedded-devices/requirements_pytorch.txt
@@ -1,4 +1,5 @@
 flwr>=1.0, <2.0
+flwr-datasets[vision]>=0.0.2, <1.0.0
 torch==1.13.1
 torchvision==0.14.1
 tqdm==4.65.0
diff --git a/examples/embedded-devices/requirements_tf.txt b/examples/embedded-devices/requirements_tf.txt
index c7068d40b9c2..ff65b9c31648 100644
--- a/examples/embedded-devices/requirements_tf.txt
+++ b/examples/embedded-devices/requirements_tf.txt
@@ -1,2 +1,3 @@
 flwr>=1.0, <2.0
+flwr-datasets[vision]>=0.0.2, <1.0.0
 tensorflow >=2.9.1, != 2.11.1
diff --git a/examples/embedded-devices/server.py b/examples/embedded-devices/server.py
index 2a15f792297e..2a6194aa5088 100644
--- a/examples/embedded-devices/server.py
+++ b/examples/embedded-devices/server.py
@@ -30,16 +30,11 @@
     default=2,
     help="Minimum number of available clients required for sampling (default: 2)",
 )
-parser.add_argument(
-    "--mnist",
-    action="store_true",
-    help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use MNIST",
-)
 
 
 # Define metric aggregation function
 def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
-    """Thist function averages teh `accuracy` metric sent by the clients in a `evaluate`
+    """This function averages the `accuracy` metric sent by the clients in an `evaluate`
     stage (i.e. clients received the global model and evaluate it on their local
     validation sets)."""
     # Multiply accuracy of each client by number of examples used
diff --git a/examples/federated-kaplan-meier-fitter/README.md b/examples/federated-kaplan-meier-fitter/README.md
new file mode 100644
index 000000000000..1569467d6f82
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/README.md
@@ -0,0 +1,104 @@
+# Flower Example using KaplanMeierFitter
+
+This is an introductory example on **federated survival analysis** using [Flower](https://flower.ai/)
+and the [lifelines](https://lifelines.readthedocs.io/en/stable/index.html) library.
+
+The aim of this example is to estimate the survival function using the
+[Kaplan-Meier Estimate](https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator) implemented in the
+lifelines library (see [KaplanMeierFitter](https://lifelines.readthedocs.io/en/stable/fitters/univariate/KaplanMeierFitter.html#lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter)). The distributed/federated aspect of this example
+is the sending of data to the server. You can think of it as a federated analytics example. However, it's worth noting that this procedure violates privacy since the raw data is exchanged.
+
+Finally, many other estimators beyond KaplanMeierFitter can be used with the provided strategy:
+AalenJohansenFitter, GeneralizedGammaFitter, LogLogisticFitter,
+SplineFitter, and WeibullFitter.
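+
+For instance, a different fitter can be plugged into the `EventTimeFitterStrategy` defined in
+`server.py` (a minimal sketch; any of the uni-variate fitters listed above that work on
+event/time data exposes the same `fit(times, events)` interface):
+
+```python
+from lifelines import WeibullFitter
+
+from server import EventTimeFitterStrategy
+
+# Swap the default KaplanMeierFitter for a WeibullFitter
+strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=WeibullFitter())
+```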
+
+We also use the [NaturalIdPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.NaturalIdPartitioner.html#flwr_datasets.partitioner.NaturalIdPartitioner) from [Flower Datasets](https://flower.ai/docs/datasets/) to divide the data according to
+the group it comes from, and thereby simulate the natural division that might occur in practice.
+
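+Below is a minimal sketch (mirroring what `client.py` in this example does) of how the
+`NaturalIdPartitioner` splits the Waltons dataset by its `group` column:
+
+```python
+from datasets import Dataset
+from flwr_datasets.partitioner import NaturalIdPartitioner
+from lifelines.datasets import load_waltons
+
+# Load the Waltons dataset and hand it to the partitioner
+partitioner = NaturalIdPartitioner(partition_by="group")
+partitioner.dataset = Dataset.from_pandas(load_waltons())
+
+# Each partition contains the rows of a single group
+partition = partitioner.load_partition(0).to_pandas()
+```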

+![Survival Function](_static/survival_function_federated.png)
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+
+```shell
+$ git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/federated-kaplan-meier-fitter . && rm -rf _tmp && cd federated-kaplan-meier-fitter
+```
+
+This will create a new directory called `federated-kaplan-meier-fitter` containing the following files:
+
+```shell
+-- pyproject.toml
+-- requirements.txt
+-- client.py
+-- server.py
+-- centralized.py
+-- README.md
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `lifelines` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies according to the configuration file requirements.txt.
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run Federated Survival Analysis with Flower and lifelines' KaplanMeierFitter
+
+### Start the long-running Flower server (SuperLink)
+
+```bash
+flower-superlink --insecure
+```
+
+### Start the long-running Flower clients (SuperNodes)
+
+In a new terminal window, start the first long-running Flower client:
+
+```bash
+flower-client-app client:node_1_app --insecure
+```
+
+In yet another new terminal window, start the second long-running Flower client:
+
+```bash
+flower-client-app client:node_2_app --insecure
+```
+
+### Run the Flower App
+
+With both the long-running server (SuperLink) and two clients (SuperNodes) up and running, we can now run the actual Flower App:
+
+```bash
+flower-server-app server:app --insecure
+```
+
+You will see the server print the survival function and the median survival time, and save a plot of the survival function.
+
+You can also check that the results match the centralized version.
+
+```shell
+$ python3 centralized.py
+```
diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png
new file mode 100644
index 000000000000..b7797f0879f2
Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png differ
diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png
new file mode 100644
index 000000000000..b7797f0879f2
Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png differ
diff --git a/examples/federated-kaplan-meier-fitter/centralized.py b/examples/federated-kaplan-meier-fitter/centralized.py
new file mode 100644
index 000000000000..c94edfef9a18
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/centralized.py
@@ -0,0 +1,16 @@
+import matplotlib.pyplot as plt
+from lifelines import KaplanMeierFitter
+from lifelines.datasets import load_waltons
+
+if __name__ == "__main__":
+    X = load_waltons()
+    fitter = KaplanMeierFitter()
+    fitter.fit(X["T"], X["E"])
+    print("Survival function:")
+    print(fitter.survival_function_)
+    print("Median survival time:")
+    print(fitter.median_survival_time_)
+    fitter.plot_survival_function()
+    plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
+    plt.savefig("./_static/survival_function_centralized.png", dpi=200)
+    print("Centralized survival function saved.")
diff --git a/examples/federated-kaplan-meier-fitter/client.py b/examples/federated-kaplan-meier-fitter/client.py
new file mode 100644
index 000000000000..948492efc575
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/client.py
@@ -0,0 +1,65 @@
+from typing import Dict, List, Tuple
+
+import flwr as fl
+import numpy as np
+from datasets import Dataset
+from flwr.common import NDArray, NDArrays
+from flwr_datasets.partitioner import NaturalIdPartitioner
+from lifelines.datasets import load_waltons
+
+
+class FlowerClient(fl.client.NumPyClient):
+    """Flower client that holds and sends the events and times data.
+
+    Parameters
+    ----------
+    times: NDArray
+        Times of the `events`.
+    events: NDArray
+        Events represented by 0 - no event, 1 - event occurred.
+
+    Raises
+    ------
+    ValueError
+        If the `times` and `events` are not the same shape.
+    """
+
+    def __init__(self, times: NDArray, events: NDArray):
+        if len(times) != len(events):
+            raise ValueError("The times and events arrays have to be the same shape.")
+        self._times = times
+        self._events = events
+
+    def fit(
+        self, parameters: List[np.ndarray], config: Dict[str, str]
+    ) -> Tuple[NDArrays, int, Dict]:
+        return (
+            [self._times, self._events],
+            len(self._times),
+            {},
+        )
+
+
+# Prepare data
+X = load_waltons()
+partitioner = NaturalIdPartitioner(partition_by="group")
+partitioner.dataset = Dataset.from_pandas(X)
+
+
+def get_client_fn(partition_id: int):
+    def client_fn(cid: str):
+        partition = partitioner.load_partition(partition_id).to_pandas()
+        events = partition["E"].values
+        times = partition["T"].values
+        return FlowerClient(times=times, events=events).to_client()
+
+    return client_fn
+
+
+# Run via `flower-client-app client:node_1_app` (and `client:node_2_app`)
+node_1_app = fl.client.ClientApp(
+    client_fn=get_client_fn(0),
+)
+node_2_app = fl.client.ClientApp(
+    client_fn=get_client_fn(1),
+)
diff --git a/examples/federated-kaplan-meier-fitter/pyproject.toml b/examples/federated-kaplan-meier-fitter/pyproject.toml
new file mode 100644
index 000000000000..8fe354ffb750
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/pyproject.toml
@@ -0,0 +1,18 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "federated-kaplan-meier-fitter"
+version = "0.1.0"
+description = "Federated Kaplan Meier Fitter with Flower"
+authors = ["The Flower Authors <hello@flower.ai>"]
+maintainers = ["The Flower Authors <hello@flower.ai>"]
+
+[tool.poetry.dependencies]
+python = ">=3.9,<3.11"
+flwr-nightly = "*"
+flwr-datasets = ">=0.0.2,<1.0.0"
+numpy = ">=1.23.2"
+pandas = ">=2.0.0"
+lifelines = ">=0.28.0"
diff --git a/examples/federated-kaplan-meier-fitter/requirements.txt b/examples/federated-kaplan-meier-fitter/requirements.txt
new file mode 100644
index 000000000000..cc8146545c7b
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/requirements.txt
@@ -0,0 +1,5 @@
+flwr-nightly
+flwr-datasets>=0.0.2, <1.0.0
+numpy>=1.23.2
+pandas>=2.0.0
+lifelines>=0.28.0
diff --git a/examples/federated-kaplan-meier-fitter/server.py b/examples/federated-kaplan-meier-fitter/server.py
new file mode 100644
index 000000000000..141504ab59c0
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/server.py
@@ -0,0 +1,145 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Strategy that supports many univariate fitters from lifelines library."""
+
+from typing import Dict, List, Optional, Tuple, Union, Any
+
+import numpy as np
+import flwr as fl
+import matplotlib.pyplot as plt
+from flwr.common import (
+    FitIns,
+    Parameters,
+    Scalar,
+    EvaluateRes,
+    EvaluateIns,
+    FitRes,
+    parameters_to_ndarrays,
+)
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import Strategy
+from lifelines import KaplanMeierFitter
+
+
+class EventTimeFitterStrategy(Strategy):
+    """Federated strategy to aggregate the data that consist of events and times.
+
+    It works with the following uni-variate fitters from the lifelines library:
+    AalenJohansenFitter, GeneralizedGammaFitter, KaplanMeierFitter, LogLogisticFitter,
+    SplineFitter, WeibullFitter. Note that each of them might require slightly different
+    initialization, but a constructed fitter object is required to be passed.
+
+    This strategy recreates the event and time data based on the data received from the
+    nodes.
+
+    Parameters
+    ----------
+    min_num_clients: int
+        Minimum number of available clients that need to be sampled.
+    fitter: Any
+        Uni-variate fitter from the lifelines library that works with event, time data,
+        e.g. KaplanMeierFitter.
+    """
+
+    def __init__(self, min_num_clients: int, fitter: Any):
+        # Fitter can be accessed after the federated training ends
+        self._min_num_clients = min_num_clients
+        self.fitter = fitter
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the fit method."""
+        config = {}
+        fit_ins = FitIns(parameters, config)
+        clients = client_manager.sample(
+            num_clients=client_manager.num_available(),
+            min_num_clients=self._min_num_clients,
+        )
+        return [(client, fit_ins) for client in clients]
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Merge data and perform the fitting of the fitter from lifelines library.
+
+        Assume just a single federated learning round. Assume the data comes as a list
+        with two elements of 1-dim numpy arrays of events and times.
+        """
+        remote_data = [
+            (parameters_to_ndarrays(fit_res.parameters)) for _, fit_res in results
+        ]
+
+        combined_times = remote_data[0][0]
+        combined_events = remote_data[0][1]
+
+        for te_list in remote_data[1:]:
+            combined_times = np.concatenate((combined_times, te_list[0]))
+            combined_events = np.concatenate((combined_events, te_list[1]))
+
+        args_sorted = np.argsort(combined_times)
+        sorted_times = combined_times[args_sorted]
+        sorted_events = combined_events[args_sorted]
+        self.fitter.fit(sorted_times, sorted_events)
+        print("Survival function:")
+        print(self.fitter.survival_function_)
+        self.fitter.plot_survival_function()
+        plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
+        plt.savefig("./_static/survival_function_federated.png", dpi=200)
+        print("Median survival time:")
+        print(self.fitter.median_survival_time_)
+        return None, {}
+
+    # The methods below return None or empty results.
+    # They need to be implemented since the methods are abstract in the parent class
+    def initialize_parameters(
+        self, client_manager: Optional[ClientManager] = None
+    ) -> Optional[Parameters]:
+        """No parameter initialization is needed."""
+        return None
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """No centralized evaluation."""
+        return None
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """No federated evaluation."""
+        return None, {}
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """No federated evaluation."""
+        return []
+
+
+fitter = KaplanMeierFitter()  # You can choose another fitter that works on E, T data
+strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=fitter)
+
+app = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=1),
+    strategy=strategy,
+)
diff --git a/examples/flower-in-30-minutes/tutorial.ipynb b/examples/flower-in-30-minutes/tutorial.ipynb
index 336ec4c19644..0e42cff924e8 100644
--- a/examples/flower-in-30-minutes/tutorial.ipynb
+++ b/examples/flower-in-30-minutes/tutorial.ipynb
@@ -7,11 +7,11 @@
    "source": [
     "Welcome to the 30 minutes Flower federated learning tutorial!\n",
     "\n",
-    "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.dev/).\n",
+    "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.ai/).\n",
     "\n",
     "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed. A minimal understanding of ML is not required but if you already know about it, nothing is stopping your from modifying this code as you see fit!\n",
     "\n",
-    "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.dev/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n",
+    "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.ai/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n",
     "\n",
     "Let's get stated!"
    ]
@@ -59,7 +59,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "We will be using the _simulation_ model in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client."
+    "We will be using the _simulation_ model in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs or even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client."
   ]
  },
 {
@@ -69,7 +69,7 @@
   "source": [
    "## Install your ML framework\n",
    "\n",
-    "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart- example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. And check the [Flower Documentation](https://flower.dev/docs/) for even more learning materials.\n",
+    "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart_ example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. And check the [Flower Documentation](https://flower.ai/docs/) for even more learning materials.\n",
    "\n",
    "In this tutorial we are going to use PyTorch, so let's install a recent version. In this tutorial we'll use a small model so using CPU only training will suffice (this will also prevent Colab from abruptly terminating your experiment if resource limits are exceeded)"
   ]
@@ -631,11 +631,11 @@
    "\n",
    "\n",
    "class FlowerClient(fl.client.NumPyClient):\n",
-    "    def __init__(self, trainloader, vallodaer) -> None:\n",
+    "    def __init__(self, trainloader, valloader) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.trainloader = trainloader\n",
-    "        self.valloader = vallodaer\n",
+    "        self.valloader = valloader\n",
    "        self.model = Net(num_classes=10)\n",
    "\n",
    "    def set_parameters(self, parameters):\n",
@@ -714,7 +714,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "def get_evalulate_fn(testloader):\n",
+    "def get_evaluate_fn(testloader):\n",
    "    \"\"\"This is a function that returns a function. The returned\n",
    "    function (i.e.
`evaluate_fn`) will be executed by the strategy\n", " at the end of each round to evaluate the stat of the global\n", @@ -747,7 +747,7 @@ " fraction_fit=0.1, # let's sample 10% of the client each round to do local training\n", " fraction_evaluate=0.1, # after each round, let's sample 20% of the clients to asses how well the global model is doing\n", " min_available_clients=100, # total number of clients available in the experiment\n", - " evaluate_fn=get_evalulate_fn(testloader),\n", + " evaluate_fn=get_evaluate_fn(testloader),\n", ") # a callback to a function that the strategy can execute to evaluate the state of the global model on a centralised dataset" ] }, @@ -775,8 +775,8 @@ " \"\"\"Returns a FlowerClient containing the cid-th data partition\"\"\"\n", "\n", " return FlowerClient(\n", - " trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n", - " )\n", + " trainloader=trainloaders[int(cid)], valloader=valloaders[int(cid)]\n", + " ).to_client()\n", "\n", " return client_fn\n", "\n", @@ -853,21 +853,21 @@ "\n", "Well, if you enjoyed this content, consider giving us a ⭐️ on GitHub -> https://github.com/adap/flower\n", "\n", - "* **[DOCS]** How about running your Flower clients on the GPU? find out how to do it in the [Flower Simulation Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html)\n", + "* **[DOCS]** How about running your Flower clients on the GPU? find out how to do it in the [Flower Simulation Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html)\n", "\n", "* **[VIDEO]** You can follow our [detailed line-by-line 9-videos tutorial](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB) about everything you need to know to design your own Flower Simulation pipelines\n", "\n", "* Check more advanced simulation examples the Flower GitHub:\n", "\n", " * Flower simulation with Tensorflow/Keras: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)\n", - " \n", + "\n", " * Flower simulation with Pytorch: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)\n", "\n", - "* **[DOCS]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** All Flower examples: https://flower.ai/docs/examples/\n", "\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n" ] }, { diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py index b10c32d36c42..eac831ad1932 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py @@ -98,7 +98,7 @@ def client_fn(cid: str): trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)], num_classes=num_classes, - ) + ).to_client() # return the function to spawn client return client_fn diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py b/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py index d269d4892a0e..7da9547d7362 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py +++ 
b/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py
@@ -75,6 +75,6 @@ def client_fn(cid: str):
             trainloader=trainloaders[int(cid)],
             vallodaer=valloaders[int(cid)],
             model_cfg=model_cfg,
-        )
+        ).to_client()
 
     return client_fn
diff --git a/examples/flower-simulation-step-by-step-pytorch/README.md b/examples/flower-simulation-step-by-step-pytorch/README.md
index 55b8d837b090..beb8dd7f6f95 100644
--- a/examples/flower-simulation-step-by-step-pytorch/README.md
+++ b/examples/flower-simulation-step-by-step-pytorch/README.md
@@ -1,5 +1,7 @@
 # Flower Simulation Step-by-Step
 
+> Since this tutorial (and its video series) was put together, Flower has been updated a few times. As a result, some of the steps to construct the environment (see below) and parts of the code have changed. Overall, the content of this tutorial and how things work remain the same as in the videos.
+
 This directory contains the code developed in the `Flower Simulation` tutorial series on Youtube. You can find all the videos [here](https://www.youtube.com/playlist?list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB) or clicking on the video preview below.
 
 - In `Part-I` (7 videos) we developed from scratch a complete Federated Learning pipeline for simulation using PyTorch.
@@ -19,20 +21,17 @@ As presented in the video, we first need to create a Python environment. You are
 # I'm assuming you are running this on an Ubuntu 22.04 machine (GPU is not required)
 
 # create the environment
-conda create -n flower_tutorial python=3.8 -y
+conda create -n flower_tutorial python=3.9 -y
 
 # activate your environment (depending on how you installed conda you might need to use `conda activate ...` instead)
 source activate flower_tutorial
 
 # install PyToch (other versions would likely work)
 conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
-# conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 -c pytorch # If you don't have a GPU
+# conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 -c pytorch -y # If you don't have a GPU
 
-# install flower (for FL) and hydra (for configs)
-pip install flwr==1.4.0 hydra-core==1.3.2
-# install ray
-# you might see some warning messages after installing it (you can ignore them)
-pip install ray==1.11.1
+# Install Flower and other dependencies
+pip install -r requirements.txt
 ```
 
 If you are running this on macOS with Apple Silicon (i.e. M1, M2), you'll need a different `grpcio` package if you see an error when running the code.
To fix this do: diff --git a/examples/flower-simulation-step-by-step-pytorch/requirements.txt b/examples/flower-simulation-step-by-step-pytorch/requirements.txt new file mode 100644 index 000000000000..a322192ca711 --- /dev/null +++ b/examples/flower-simulation-step-by-step-pytorch/requirements.txt @@ -0,0 +1,2 @@ +flwr[simulation]>=1.0, <2.0 +hydra-core==1.3.2 \ No newline at end of file diff --git a/examples/flower-via-docker-compose/.gitignore b/examples/flower-via-docker-compose/.gitignore new file mode 100644 index 000000000000..de5b2e7692bf --- /dev/null +++ b/examples/flower-via-docker-compose/.gitignore @@ -0,0 +1,17 @@ +# ignore __pycache__ directories +__pycache__/ + +# ignore .pyc files +*.pyc + +# ignore .vscode directory +.vscode/ + +# ignore .npz files +*.npz + +# ignore .csv files +*.csv + +# ignore docker-compose.yaml file +docker-compose.yml \ No newline at end of file diff --git a/examples/flower-via-docker-compose/Dockerfile b/examples/flower-via-docker-compose/Dockerfile new file mode 100644 index 000000000000..ee6fee3103a5 --- /dev/null +++ b/examples/flower-via-docker-compose/Dockerfile @@ -0,0 +1,19 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim-buster + +# Set the working directory in the container to /app +WORKDIR /app + +# Copy the requirements file into the container +COPY ./requirements.txt /app/requirements.txt + +# Install gcc and other dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + python3-dev && \ + rm -rf /var/lib/apt/lists/* + +# Install any needed packages specified in requirements.txt +RUN pip install -r requirements.txt + + diff --git a/examples/flower-via-docker-compose/README.md b/examples/flower-via-docker-compose/README.md new file mode 100644 index 000000000000..3ef1ac37bcda --- /dev/null +++ b/examples/flower-via-docker-compose/README.md @@ -0,0 +1,254 @@ +# Leveraging Flower and Docker for Device Heterogeneity Management in Federated Learning + +

+<!-- Banner images: "Flower Website" and "Docker Logo" -->
+
+## Introduction
+
+In this example, we tackle device heterogeneity in federated learning, arising from differences in memory and CPU capabilities across devices. This diversity affects training efficiency and inclusivity. Our strategy includes simulating this heterogeneity by setting CPU and memory limits in a Docker setup, using a custom Docker compose generator script. This approach creates a varied training environment and enables us to develop strategies to manage these disparities effectively.
+
+## Handling Device Heterogeneity
+
+1. **System Metrics Access**:
+
+   - Effective management of device heterogeneity begins with monitoring system metrics of each container. We integrate the following services to achieve this:
+     - **cAdvisor**: Collects comprehensive metrics from each Docker container.
+     - **Prometheus**: Using `prometheus.yml` for configuration, it scrapes data from cAdvisor at scheduled intervals, serving as a robust time-series database. Users can access the Prometheus UI at `http://localhost:9090` to create and run queries using PromQL, allowing for detailed insight into container performance.
+
+2. **Mitigating Heterogeneity**:
+
+   - In this basic use case, we address device heterogeneity by establishing rules tailored to each container's system capabilities. This involves modifying training parameters, such as batch sizes and learning rates, based on each device's memory capacity and CPU availability. These settings are specified in the `client_configs` array in the `create_docker_compose` script. For example:
+
+     ```python
+     client_configs = [
+         {"mem_limit": "3g", "batch_size": 32, "cpus": 4, "learning_rate": 0.001},
+         {"mem_limit": "6g", "batch_size": 256, "cpus": 1, "learning_rate": 0.05},
+         {"mem_limit": "4g", "batch_size": 64, "cpus": 3, "learning_rate": 0.02},
+         {"mem_limit": "5g", "batch_size": 128, "cpus": 2.5, "learning_rate": 0.09},
+     ]
+     ```
+
+## Prerequisites
+
+Docker must be installed and the Docker daemon running on your server. If you don't already have Docker installed, you can get [installation instructions for your specific Linux distribution or macOS from Docker](https://docs.docker.com/engine/install/). Besides Docker, the only extra requirement is having Python installed. You don't need to create a new environment for this example since all dependencies will be installed inside Docker containers automatically.
+
+## Running the Example
+
+Running this example is easy. For a more detailed step-by-step walkthrough, including more useful material, refer to the following section.
+
+```bash
+
+# Generate docker compose file
+python helpers/generate_docker_compose.py # by default will configure to use 2 clients for 100 rounds
+
+# Build docker images
+docker-compose build
+
+# Launch everything
+docker-compose up
+```
+
+In your favourite browser, go to `http://localhost:3000` to see the Grafana dashboard showing system-level and application-level metrics.
+
+To stop all containers, open a new terminal and `cd` into this directory, then run `docker-compose down`. Alternatively, you can do `ctrl+c` on the same terminal and then run `docker-compose down` to ensure everything is terminated.
+
+## Running the Example (detailed)
+
+### Step 1: Configure Docker Compose
+
+Execute the following command to run the `helpers/generate_docker_compose.py` script. This script creates the docker-compose configuration needed to set up the environment.
+ +```bash +python helpers/generate_docker_compose.py +``` + +Within the script, specify the number of clients (`total_clients`) and resource limitations for each client in the `client_configs` array. You can adjust the number of rounds by passing `--num_rounds` to the above command. + +### Step 2: Build and Launch Containers + +1. **Execute Initialization Script**: + + - To build the Docker images and start the containers, use the following command: + + ```bash + # this is the only command you need to execute to run the entire example + docker-compose up + ``` + + - If you make any changes to the Dockerfile or other configuration files, you should rebuild the images to reflect these changes. This can be done by adding the `--build` flag to the command: + + ```bash + docker-compose up --build + ``` + + - The `--build` flag instructs Docker Compose to rebuild the images before starting the containers, ensuring that any code or configuration changes are included. + + - To stop all services, you have two options: + + - Run `docker-compose down` in another terminal if you are in the same directory. This command will stop and remove the containers, networks, and volumes created by `docker-compose up`. + - Press `Ctrl+C` once in the terminal where `docker-compose up` is running. This will stop the containers but won't remove them or the networks and volumes they use. + +2. **Services Startup**: + + - Several services will automatically launch as defined in your `docker-compose.yml` file: + + - **Monitoring Services**: Prometheus for metrics collection, Cadvisor for container monitoring, and Grafana for data visualization. + - **Flower Federated Learning Environment**: The Flower server and client containers are initialized and start running. + + - After launching the services, verify that all Docker containers are running correctly by executing the `docker ps` command. Here's an example output: + + ```bash + ➜ ~ docker ps + CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES + 9f05820eba45 flower-via-docker-compose-client2 "python client.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6002->6002/tcp client2 + a0333715d504 flower-via-docker-compose-client1 "python client.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6001->6001/tcp client1 + 0da2bf735965 flower-via-docker-compose-server "python server.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6000->6000/tcp, 0.0.0.0:8000->8000/tcp, 0.0.0.0:8265->8265/tcp server + c57ef50657ae grafana/grafana:latest "/run.sh --config=/e…" 50 seconds ago Up 49 seconds 0.0.0.0:3000->3000/tcp grafana + 4f274c2083dc prom/prometheus:latest "/bin/prometheus --c…" 50 seconds ago Up 49 seconds 0.0.0.0:9090->9090/tcp prometheus + e9f4c9644a1c gcr.io/cadvisor/cadvisor:v0.47.0 "/usr/bin/cadvisor -…" 50 seconds ago Up 49 seconds 0.0.0.0:8080->8080/tcp cadvisor + ``` + + - To monitor the resource utilization of your containers in real-time and see the limits imposed in the Docker Compose file, you can use the `docker stats` command. This command provides a live stream of container CPU, memory, and network usage statistics. 
+ + ```bash + ➜ ~ docker stats + CONTAINER ID NAME CPU % MEM USAGE / LIMIT MEM % NET I/O BLOCK I/O PIDS + 9f05820eba45 client2 104.44% 1.968GiB / 6GiB 32.80% 148MB / 3.22MB 0B / 284MB 82 + a0333715d504 client1 184.69% 1.498GiB / 3GiB 49.92% 149MB / 2.81MB 1.37MB / 284MB 82 + 0da2bf735965 server 0.12% 218.5MiB / 15.61GiB 1.37% 1.47MB / 2.89MB 2.56MB / 2.81MB 45 + c57ef50657ae grafana 0.24% 96.19MiB / 400MiB 24.05% 18.9kB / 3.79kB 77.8kB / 152kB 20 + 4f274c2083dc prometheus 1.14% 52.73MiB / 500MiB 10.55% 6.79MB / 211kB 1.02MB / 1.31MB 15 + e9f4c9644a1c cadvisor 7.31% 32.14MiB / 500MiB 6.43% 139kB / 6.66MB 500kB / 0B 18 + ``` + +3. **Automated Grafana Configuration**: + + - Grafana is configured to load pre-defined data sources and dashboards for immediate monitoring, facilitated by provisioning files. The provisioning files include `prometheus-datasource.yml` for data sources, located in the `./config/provisioning/datasources` directory, and `dashboard_index.json` for dashboards, in the `./config/provisioning/dashboards` directory. The `grafana.ini` file is also tailored to enhance user experience: + - **Admin Credentials**: We provide default admin credentials in the `grafana.ini` configuration, which simplifies access by eliminating the need for users to go through the initial login process. + - **Default Dashboard Path**: A default dashboard path is set in `grafana.ini` to ensure that the dashboard with all the necessary panels is rendered when Grafana is accessed. + + These files and settings are directly mounted into the Grafana container via Docker Compose volume mappings. This setup guarantees that upon startup, Grafana is pre-configured for monitoring, requiring no additional manual setup. + +4. **Begin Training Process**: + + - The federated learning training automatically begins once all client containers are successfully connected to the Flower server. This synchronizes the learning process across all participating clients. + +By following these steps, you will have a fully functional federated learning environment with device heterogeneity and monitoring capabilities. + +## Model Training and Dataset Integration + +### Data Pipeline with FLWR-Datasets + +We have integrated [`flwr-datasets`](https://flower.ai/docs/datasets/) into our data pipeline, which is managed within the `load_data.py` file in the `helpers/` directory. This script facilitates standardized access to datasets across the federated network and incorporates a `data_sampling_percentage` argument. This argument allows users to specify the percentage of the dataset to be used for training and evaluation, accommodating devices with lower memory capabilities to prevent Out-of-Memory (OOM) errors. + +### Model Selection and Dataset + +For the federated learning system, we have selected the MobileNet model due to its efficiency in image classification tasks. The model is trained and evaluated on the CIFAR-10 dataset. The combination of MobileNet and CIFAR-10 is ideal for demonstrating the capabilities of our federated learning solution in a heterogeneous device environment. + +- **MobileNet**: A streamlined architecture for mobile and embedded devices that balances performance and computational cost. +- **CIFAR-10 Dataset**: A standard benchmark dataset for image classification, containing various object classes that pose a comprehensive challenge for the learning model. 
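+
+As a rough sketch of how the `data_sampling_percentage` argument described above can be applied (the actual logic lives in `helpers/load_data.py`, which is not shown in this diff, so the function body below is an assumption), each client's partition can be subsampled before training:
+
+```python
+from flwr_datasets import FederatedDataset
+
+
+def load_data(data_sampling_percentage=0.5, client_id=1, total_clients=2):
+    """Load a client's CIFAR-10 partition and subsample it (hypothetical sketch)."""
+    fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_clients})
+    partition = fds.load_partition(client_id - 1, "train")
+    partition.set_format("numpy")
+    # Keep only a fraction of the partition to avoid OOM on constrained devices
+    num_samples = int(len(partition) * data_sampling_percentage)
+    partition = partition.shuffle(seed=42).select(range(num_samples))
+    # Divide data on each node: 80% train, 20% test
+    partition = partition.train_test_split(test_size=0.2)
+    x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
+    x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
+    return (x_train, y_train), (x_test, y_test)
+```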
+
+By integrating these components, our framework is well-prepared to handle the intricacies of training over a distributed network with varying device capabilities and data availability.
+
+## Visualizing with Grafana
+
+### Access Grafana Dashboard
+
+Visit `http://localhost:3000` to enter Grafana. The automated setup ensures that you're greeted with a series of pre-configured dashboards, including the default screen with a comprehensive set of graphs. These dashboards are ready for immediate monitoring and can be customized to suit your specific requirements.
+
+### Dashboard Configuration
+
+The `dashboard_index.json` file, located in the `./config/provisioning/dashboards` directory, serves as the backbone of our Grafana dashboard's configuration. It defines the structure and settings of the dashboard panels, which are rendered when you access Grafana. This JSON file contains the specifications for various panels such as model accuracy, CPU usage, memory utilization, and network traffic. Each panel's configuration includes the data source, queries, visualization type, and other display settings like thresholds and colors.
+
+For instance, in our project setup, the `dashboard_index.json` configures a panel to display the model's accuracy over time using a time-series graph, and another panel to show the CPU usage across clients using a graph that plots data points as they are received. This file is fundamental for creating a customized and informative dashboard that provides a snapshot of the federated learning system's health and performance metrics.
+
+By modifying the `dashboard_index.json` file, users can tailor the Grafana dashboard to include additional metrics or change the appearance and behavior of existing panels to better fit their monitoring requirements.
+
+### Grafana Default Dashboard
+
+Below is the default Grafana dashboard that users will see upon accessing Grafana:
+
+<!-- Image: grafana_home_screen -->
+
+This comprehensive dashboard provides insights into various system metrics across client-server containers. It includes visualizations such as:
+
+- **Application Metrics**: The "Model Accuracy" graph shows an upward trend as rounds of training progress, which is a positive indicator of the model learning and improving over time. Conversely, the "Model Loss" graph trends downward, suggesting that the model is becoming more precise and making fewer mistakes as it trains.
+
+- **CPU Usage**: The sharp spikes in the red graph, representing "client1", indicate peak CPU usage, which is considerably higher than that of "client2" (blue graph). This difference is due to "client1" being allocated more computing resources (up to 4 CPU cores) compared to "client2", which is limited to just 1 CPU core, hence the more subdued CPU usage pattern.
+
+- **Memory Utilization**: Both clients are allocated a similar amount of memory, reflected in nearly identical memory-usage lines. This uniform allocation allows for a straightforward comparison of how each client manages memory under similar conditions.
+
+- **Network Traffic**: Monitor incoming and outgoing network traffic to each client, which is crucial for understanding data exchange volumes during federated learning cycles.
+
+Together, these metrics paint a detailed picture of the federated learning operation, showcasing resource usage and model performance. Such insights are invaluable for system optimization, ensuring balanced load distribution and efficient model training.
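+
+For reference, the CPU panel defined in `dashboard_index.json` (included later in this diff) computes per-container CPU usage with a PromQL query along these lines (the `$host` and `$container` template variables are simplified away here):
+
+```promql
+sum(rate(container_cpu_usage_seconds_total{name=~".+", name !~ "(prometheus|cadvisor|grafana)"}[10s])) by (name) * 100
+```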
+
+## Comprehensive Monitoring System Integration
+
+### Capturing Container Metrics with cAdvisor
+
+cAdvisor is seamlessly integrated into our monitoring setup to capture a variety of system and container metrics, such as CPU, memory, and network usage. These metrics are vital for analyzing the performance and resource consumption of the containers in the federated learning environment.
+
+### Custom Metrics: Setup and Monitoring via Prometheus
+
+In addition to the standard metrics captured by cAdvisor, we have implemented a process to track custom metrics like the model's accuracy and loss within Grafana, using Prometheus as the backbone for metric collection.
+
+1. **Prometheus Client Installation**:
+
+   - We began by installing the `prometheus_client` library in our Python environment, enabling us to define and expose custom metrics that Prometheus can scrape.
+
+2. **Defining Metrics in Server Script**:
+
+   - Within our `server.py` script, we have established two key Prometheus Gauge metrics, specifically tailored for monitoring our federated learning model: `model_accuracy` and `model_loss`. These custom gauges are instrumental in capturing the most recent values of the model's accuracy and loss, which are essential metrics for evaluating the model's performance. The gauges are defined as follows:
+
+     ```python
+     from prometheus_client import Gauge
+
+     accuracy_gauge = Gauge('model_accuracy', 'Current accuracy of the global model')
+     loss_gauge = Gauge('model_loss', 'Current loss of the global model')
+     ```
+
+3. **Exposing Metrics via HTTP Endpoint**:
+
+   - We leveraged the `start_http_server` function from the `prometheus_client` library to launch an HTTP server on port 8000. This server provides the `/metrics` endpoint, where the custom metrics are accessible for Prometheus scraping. The function is called at the end of the `main` method in `server.py`:
+
+     ```python
+     start_http_server(8000)
+     ```
+
+4. **Updating Metrics Recording Strategy**:
+
+   - The core of our metrics tracking lies in the `strategy.py` file, particularly within the `aggregate_evaluate` method. This method is crucial as it's where the federated learning model's accuracy and loss values are computed after each round of training with the aggregated data from all clients.
+
+     ```python
+     self.accuracy_gauge.set(accuracy_aggregated)
+     self.loss_gauge.set(loss_aggregated)
+     ```
+
+5. **Configuring Prometheus Scraping**:
+
+   - In the `prometheus.yml` file, under `scrape_configs`, we configured a new job to scrape the custom metrics from the HTTP server. This setup includes the job's name, the scraping interval, and the target server's URL.
+
+### Visualizing the Monitoring Architecture
+
+The image below depicts the Prometheus scraping process as it is configured in our monitoring setup. Within this architecture:
+
+- The "Prometheus server" is the central component that retrieves and stores metrics.
+- "cAdvisor" and the "HTTP server" we set up to expose our custom metrics are represented as "Prometheus targets" in the diagram. cAdvisor captures container metrics, while the HTTP server serves our custom `model_accuracy` and `model_loss` metrics at the `/metrics` endpoint.
+- These targets are periodically scraped by the Prometheus server, aggregating data from both system-level and custom performance metrics.
+- The aggregated data is then made available to the "Prometheus web UI" and "Grafana," as shown, enabling detailed visualization and analysis through the Grafana dashboard.
+
+<!-- Image: prometheus-architecture -->
+
+By incorporating these steps, we have enriched our monitoring capabilities to not only include system-level metrics but also critical performance indicators of our federated learning model. This approach is pivotal for understanding and improving the learning process. Similarly, you can apply this methodology to track any other metric that you find interesting or relevant to your specific needs. This flexibility allows for a comprehensive and customized monitoring environment, tailored to the unique aspects and requirements of your federated learning system.
+
+## Additional Resources
+
+- **Grafana Tutorials**: Explore a variety of tutorials on Grafana at [Grafana Tutorials](https://grafana.com/tutorials/).
+- **Prometheus Overview**: Learn more about Prometheus at their [official documentation](https://prometheus.io/docs/introduction/overview/).
+- **cAdvisor Guide**: For information on monitoring Docker containers with cAdvisor, see this [Prometheus guide](https://prometheus.io/docs/guides/cadvisor/).
+
+## Conclusion
+
+This project serves as a foundational example of managing device heterogeneity within the federated learning context, employing the Flower framework alongside Docker, Prometheus, and Grafana. It's designed to be a starting point for users to explore and further adapt to the complexities of device heterogeneity in federated learning environments.
diff --git a/examples/flower-via-docker-compose/client.py b/examples/flower-via-docker-compose/client.py
new file mode 100644
index 000000000000..c894143532a1
--- /dev/null
+++ b/examples/flower-via-docker-compose/client.py
@@ -0,0 +1,110 @@
+import os
+import argparse
+import flwr as fl
+import tensorflow as tf
+import logging
+from helpers.load_data import load_data
+from model.model import Model
+
+logging.basicConfig(level=logging.INFO)  # Configure logging
+logger = logging.getLogger(__name__)  # Create logger for the module
+
+# Make TensorFlow log less verbose
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+# Parse command line arguments
+parser = argparse.ArgumentParser(description="Flower client")
+
+parser.add_argument(
+    "--server_address", type=str, default="server:8080", help="Address of the server"
+)
+parser.add_argument(
+    "--batch_size", type=int, default=32, help="Batch size for training"
+)
+parser.add_argument(
+    "--learning_rate", type=float, default=0.1, help="Learning rate for the optimizer"
+)
+parser.add_argument("--client_id", type=int, default=1, help="Unique ID for the client")
+parser.add_argument(
+    "--total_clients", type=int, default=2, help="Total number of clients"
+)
+parser.add_argument(
+    "--data_percentage", type=float, default=0.5, help="Portion of client data to use"
+)
+
+args = parser.parse_args()
+
+# Create an instance of the model and pass the learning rate as an argument
+model = Model(learning_rate=args.learning_rate)
+
+# Compile the model
+model.compile()
+
+
+class Client(fl.client.NumPyClient):
+    def __init__(self, args):
+        self.args = args
+
+        logger.info("Preparing data...")
+        (x_train, y_train), (x_test, y_test) = load_data(
+            data_sampling_percentage=self.args.data_percentage,
+            client_id=self.args.client_id,
+            total_clients=self.args.total_clients,
+        )
+
+        self.x_train = x_train
+        self.y_train = y_train
+        self.x_test = x_test
+        self.y_test = y_test
+
+    def get_parameters(self, config):
+        # Return the parameters of the model
+        return model.get_model().get_weights()
+
+    def fit(self, parameters, config):
+        # Set the weights of the model
+
model.get_model().set_weights(parameters) + + # Train the model + history = model.get_model().fit( + self.x_train, self.y_train, batch_size=self.args.batch_size + ) + + # Calculate evaluation metric + results = { + "accuracy": float(history.history["accuracy"][-1]), + } + + # Get the parameters after training + parameters_prime = model.get_model().get_weights() + + # Directly return the parameters and the number of examples trained on + return parameters_prime, len(self.x_train), results + + def evaluate(self, parameters, config): + # Set the weights of the model + model.get_model().set_weights(parameters) + + # Evaluate the model and get the loss and accuracy + loss, accuracy = model.get_model().evaluate( + self.x_test, self.y_test, batch_size=self.args.batch_size + ) + + # Return the loss, the number of examples evaluated on and the accuracy + return float(loss), len(self.x_test), {"accuracy": float(accuracy)} + + +# Function to Start the Client +def start_fl_client(): + try: + client = Client(args).to_client() + fl.client.start_client(server_address=args.server_address, client=client) + except Exception as e: + logger.error("Error starting FL client: %s", e) + return {"status": "error", "message": str(e)} + + +if __name__ == "__main__": + # Call the function to start the client + start_fl_client() diff --git a/examples/flower-via-docker-compose/config/grafana.ini b/examples/flower-via-docker-compose/config/grafana.ini new file mode 100644 index 000000000000..775f39d7ec22 --- /dev/null +++ b/examples/flower-via-docker-compose/config/grafana.ini @@ -0,0 +1,12 @@ +[security] +allow_embedding = true +admin_user = admin +admin_password = admin + +[dashboards] +default_home_dashboard_path = /etc/grafana/provisioning/dashboards/dashboard_index.json + +[auth.anonymous] +enabled = true +org_name = Main Org. 
+org_role = Admin diff --git a/examples/flower-via-docker-compose/config/prometheus.yml b/examples/flower-via-docker-compose/config/prometheus.yml new file mode 100644 index 000000000000..46cf07b9dcee --- /dev/null +++ b/examples/flower-via-docker-compose/config/prometheus.yml @@ -0,0 +1,19 @@ + +global: + scrape_interval: 1s + evaluation_interval: 1s + +rule_files: +scrape_configs: + - job_name: 'cadvisor' + scrape_interval: 1s + metrics_path: '/metrics' + static_configs: + - targets: ['cadvisor:8080'] + labels: + group: 'cadvisor' + - job_name: 'server_metrics' + scrape_interval: 1s + metrics_path: '/metrics' + static_configs: + - targets: ['server:8000'] \ No newline at end of file diff --git a/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json new file mode 100644 index 000000000000..b52f19c57508 --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json @@ -0,0 +1,1051 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "datasource", + "uid": "grafana" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Simple exporter for cadvisor only", + "editable": true, + "fiscalYearStartMonth": 0, + "gnetId": 14282, + "graphTooltip": 0, + "id": 12, + "links": [], + "liveNow": false, + "panels": [ + { + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 22, + "title": "Application metrics", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "description": "Averaged federated accuracy across clients", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineStyle": { + "fill": "solid" + }, + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 1 + }, + "id": 23, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "model_accuracy", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Model Accuracy", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "description": "Averaged Federated Loss 
across clients", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 1 + }, + "id": 21, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "model_loss", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Model Loss", + "type": "timeseries" + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 9 + }, + "id": 8, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "CPU", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 10 + }, + "hiddenSeries": false, + "id": 15, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_cpu_usage_seconds_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name) *100", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "CPU Usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "percent", + "logBase": 1, + 
"show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 17 + }, + "id": 11, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Memory", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 18 + }, + "hiddenSeries": false, + "id": 9, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(container_memory_rss{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Memory Usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "bytes", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 18 + }, + "hiddenSeries": false, + "id": 14, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(container_memory_cache{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], 
+ "timeRegions": [], + "title": "Memory Cached", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "bytes", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 26 + }, + "id": 2, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Network", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 27 + }, + "hiddenSeries": false, + "id": 4, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "hideEmpty": false, + "hideZero": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_network_receive_bytes_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Received Network Traffic", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:674", + "format": "Bps", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:675", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 27 + }, + "hiddenSeries": false, + "id": 6, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": 
"prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_network_transmit_bytes_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name)", + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Sent Network Traffic", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:832", + "format": "Bps", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:833", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 35 + }, + "id": 19, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Misc", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fieldConfig": { + "defaults": { + "custom": { + "align": "auto", + "cellOptions": { + "type": "auto" + }, + "filterable": false, + "inspect": false + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "id" + }, + "properties": [ + { + "id": "custom.width", + "value": 260 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Running" + }, + "properties": [ + { + "id": "unit", + "value": "d" + }, + { + "id": "decimals", + "value": 1 + }, + { + "id": "custom.cellOptions", + "value": { + "type": "color-text" + } + }, + { + "id": "color", + "value": { + "fixedColor": "dark-green", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 10, + "w": 24, + "x": 0, + "y": 36 + }, + "id": 17, + "options": { + "cellHeight": "sm", + "footer": { + "countRows": false, + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, + "showHeader": true, + "sortBy": [] + }, + "pluginVersion": "10.2.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "expr": "(time() - container_start_time_seconds{instance=~\"$host\",name=~\"$container\",name=~\".+\"})/86400", + "format": "table", + "instant": true, + "interval": "", + "legendFormat": "{{name}}", + "refId": "A" + } + ], + "title": "Containers Info", + "transformations": [ + { + "id": "filterFieldsByName", + "options": { + "include": { + "names": [ + "container_label_com_docker_compose_project", + "container_label_com_docker_compose_project_working_dir", + "image", + "instance", + "name", + "Value", + "container_label_com_docker_compose_service" + ] + } + } + }, + { + "id": "organize", + "options": { + "excludeByName": {}, + "indexByName": {}, + "renameByName": { + "Value": "Running", + "container_label_com_docker_compose_project": "Label", + "container_label_com_docker_compose_project_working_dir": "Working dir", + "container_label_com_docker_compose_service": "Service", + "image": "Registry Image", + "instance": "Instance", + "name": "Name" + } + } + } + ], + "type": "table" + } + ], + "refresh": "auto", + 
"schemaVersion": 38, + "tags": [], + "templating": { + "list": [ + { + "allValue": ".*", + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "definition": "label_values({__name__=~\"container.*\"},instance)", + "hide": 0, + "includeAll": true, + "label": "Host", + "multi": false, + "name": "host", + "options": [], + "query": { + "query": "label_values({__name__=~\"container.*\"},instance)", + "refId": "Prometheus-host-Variable-Query" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 5, + "tagValuesQuery": "", + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".*", + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "definition": "label_values({__name__=~\"container.*\", instance=~\"$host\"},name)", + "hide": 0, + "includeAll": true, + "label": "Container", + "multi": false, + "name": "container", + "options": [], + "query": { + "query": "label_values({__name__=~\"container.*\", instance=~\"$host\"},name)", + "refId": "Prometheus-container-Variable-Query" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Cadvisor exporter Copy", + "uid": "fcf2a8da-792c-4b9f-a22f-876820b53c2f", + "version": 2, + "weekStart": "" +} \ No newline at end of file diff --git a/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml new file mode 100644 index 000000000000..e0d542f58f2b --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml @@ -0,0 +1,12 @@ +apiVersion: 1 + +providers: +- name: 'default' + orgId: 1 + folder: '' + type: file + disableDeletion: false + editable: true + updateIntervalSeconds: 5 # Optional: How often Grafana will scan for changed dashboards + options: + path: /etc/grafana/provisioning/dashboards diff --git a/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml b/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml new file mode 100644 index 000000000000..7c8ce00fdcdc --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml @@ -0,0 +1,9 @@ +apiVersion: 1 + +datasources: +- name: Prometheus + type: prometheus + access: proxy + uid: db69454e-e558-479e-b4fc-80db52bf91da + url: http://host.docker.internal:9090 + isDefault: true diff --git a/examples/flower-via-docker-compose/helpers/generate_docker_compose.py b/examples/flower-via-docker-compose/helpers/generate_docker_compose.py new file mode 100644 index 000000000000..cde553a95e68 --- /dev/null +++ b/examples/flower-via-docker-compose/helpers/generate_docker_compose.py @@ -0,0 +1,147 @@ +import random +import argparse + +parser = argparse.ArgumentParser(description="Generated Docker Compose") +parser.add_argument( + "--total_clients", type=int, default=2, help="Total clients to spawn (default: 2)" +) +parser.add_argument( + "--num_rounds", type=int, default=100, help="Number of FL rounds (default: 100)" +) +parser.add_argument( + "--data_percentage", 
+ type=float, + default=0.6, + help="Portion of client data to use (default: 0.6)", +) +parser.add_argument( + "--random", action="store_true", help="Randomize client configurations" +) + + +def create_docker_compose(args): + # cpus is used to set the number of CPUs available to the container as a fraction of the total number of CPUs on the host machine. + # mem_limit is used to set the memory limit for the container. + client_configs = [ + {"mem_limit": "3g", "batch_size": 32, "cpus": 4, "learning_rate": 0.001}, + {"mem_limit": "6g", "batch_size": 256, "cpus": 1, "learning_rate": 0.05}, + {"mem_limit": "4g", "batch_size": 64, "cpus": 3, "learning_rate": 0.02}, + {"mem_limit": "5g", "batch_size": 128, "cpus": 2.5, "learning_rate": 0.09}, + # Add or modify the configurations depending on your host machine + ] + + docker_compose_content = f""" +version: '3' +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + ports: + - 9090:9090 + deploy: + restart_policy: + condition: on-failure + command: + - --config.file=/etc/prometheus/prometheus.yml + volumes: + - ./config/prometheus.yml:/etc/prometheus/prometheus.yml:ro + depends_on: + - cadvisor + + cadvisor: + image: gcr.io/cadvisor/cadvisor:v0.47.0 + container_name: cadvisor + privileged: true + deploy: + restart_policy: + condition: on-failure + ports: + - "8080:8080" + volumes: + - /:/rootfs:ro + - /var/run:/var/run:ro + - /sys:/sys:ro + - /var/lib/docker/:/var/lib/docker:ro + - /dev/disk/:/dev/disk:ro + - /var/run/docker.sock:/var/run/docker.sock + + grafana: + image: grafana/grafana:latest + container_name: grafana + ports: + - 3000:3000 + deploy: + restart_policy: + condition: on-failure + volumes: + - grafana-storage:/var/lib/grafana + - ./config/grafana.ini:/etc/grafana/grafana.ini + - ./config/provisioning/datasources:/etc/grafana/provisioning/datasources + - ./config/provisioning/dashboards:/etc/grafana/provisioning/dashboards + depends_on: + - prometheus + - cadvisor + command: + - --config=/etc/grafana/grafana.ini + + + server: + container_name: server + build: + context: . + dockerfile: Dockerfile + command: python server.py --number_of_rounds={args.num_rounds} + environment: + FLASK_RUN_PORT: 6000 + DOCKER_HOST_IP: host.docker.internal + volumes: + - .:/app + - /var/run/docker.sock:/var/run/docker.sock + ports: + - "6000:6000" + - "8265:8265" + - "8000:8000" + depends_on: + - prometheus + - grafana +""" + # Add client services + for i in range(1, args.total_clients + 1): + if args.random: + config = random.choice(client_configs) + else: + config = client_configs[(i - 1) % len(client_configs)] + docker_compose_content += f""" + client{i}: + container_name: client{i} + build: + context: . 
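+        # All clients build from the same image; the per-client CPU/memory
+        # limits from client_configs are applied in the deploy section below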
+ dockerfile: Dockerfile + command: python client.py --server_address=server:8080 --data_percentage={args.data_percentage} --client_id={i} --total_clients={args.total_clients} --batch_size={config["batch_size"]} --learning_rate={config["learning_rate"]} + deploy: + resources: + limits: + cpus: "{(config['cpus'])}" + memory: "{config['mem_limit']}" + volumes: + - .:/app + - /var/run/docker.sock:/var/run/docker.sock + ports: + - "{6000 + i}:{6000 + i}" + depends_on: + - server + environment: + FLASK_RUN_PORT: {6000 + i} + container_name: client{i} + DOCKER_HOST_IP: host.docker.internal +""" + + docker_compose_content += "volumes:\n grafana-storage:\n" + + with open("docker-compose.yml", "w") as file: + file.write(docker_compose_content) + + +if __name__ == "__main__": + args = parser.parse_args() + create_docker_compose(args) diff --git a/examples/flower-via-docker-compose/helpers/load_data.py b/examples/flower-via-docker-compose/helpers/load_data.py new file mode 100644 index 000000000000..1f2784946868 --- /dev/null +++ b/examples/flower-via-docker-compose/helpers/load_data.py @@ -0,0 +1,37 @@ +import numpy as np +import tensorflow as tf +from flwr_datasets import FederatedDataset +import logging + +logging.basicConfig(level=logging.INFO) # Configure logging +logger = logging.getLogger(__name__) # Create logger for the module + + +def load_data(data_sampling_percentage=0.5, client_id=1, total_clients=2): + """Load federated dataset partition based on client ID. + + Args: + data_sampling_percentage (float): Percentage of the dataset to use for training. + client_id (int): Unique ID for the client. + total_clients (int): Total number of clients. + + Returns: + Tuple of arrays: `(x_train, y_train), (x_test, y_test)`. + """ + + # Download and partition dataset + fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_clients}) + partition = fds.load_partition(client_id - 1, "train") + partition.set_format("numpy") + + # Divide data on each client: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + + # Apply data sampling + num_samples = int(data_sampling_percentage * len(x_train)) + indices = np.random.choice(len(x_train), num_samples, replace=False) + x_train, y_train = x_train[indices], y_train[indices] + + return (x_train, y_train), (x_test, y_test) diff --git a/examples/flower-via-docker-compose/model/model.py b/examples/flower-via-docker-compose/model/model.py new file mode 100644 index 000000000000..ab26d089b858 --- /dev/null +++ b/examples/flower-via-docker-compose/model/model.py @@ -0,0 +1,18 @@ +import tensorflow as tf + + +# Class for the model. 
In this case, we are using the MobileNetV2 model from Keras +class Model: + def __init__(self, learning_rate): + self.learning_rate = learning_rate + self.loss_function = tf.keras.losses.SparseCategoricalCrossentropy() + self.model = tf.keras.applications.MobileNetV2( + (32, 32, 3), alpha=0.1, classes=10, weights=None + ) + self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + def compile(self): + self.model.compile(self.optimizer, self.loss_function, metrics=["accuracy"]) + + def get_model(self): + return self.model diff --git a/examples/flower-via-docker-compose/requirements.txt b/examples/flower-via-docker-compose/requirements.txt new file mode 100644 index 000000000000..b93e5b1d9f2b --- /dev/null +++ b/examples/flower-via-docker-compose/requirements.txt @@ -0,0 +1,5 @@ +flwr==1.7.0 +tensorflow==2.13.1 +numpy==1.24.3 +prometheus_client == 0.19.0 +flwr_datasets[vision] == 0.0.2 diff --git a/examples/flower-via-docker-compose/server.py b/examples/flower-via-docker-compose/server.py new file mode 100644 index 000000000000..99d1a7ef7399 --- /dev/null +++ b/examples/flower-via-docker-compose/server.py @@ -0,0 +1,47 @@ +import argparse +import flwr as fl +import logging +from strategy.strategy import FedCustom +from prometheus_client import start_http_server, Gauge + +# Initialize Logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define a gauge to track the global model accuracy +accuracy_gauge = Gauge("model_accuracy", "Current accuracy of the global model") + +# Define a gauge to track the global model loss +loss_gauge = Gauge("model_loss", "Current loss of the global model") + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Flower Server") +parser.add_argument( + "--number_of_rounds", + type=int, + default=100, + help="Number of FL rounds (default: 100)", +) +args = parser.parse_args() + + +# Function to Start Federated Learning Server +def start_fl_server(strategy, rounds): + try: + fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=rounds), + strategy=strategy, + ) + except Exception as e: + logger.error(f"FL Server error: {e}", exc_info=True) + + +# Main Function +if __name__ == "__main__": + # Start Prometheus Metrics Server + start_http_server(8000) + + # Initialize Strategy Instance and Start FL Server + strategy_instance = FedCustom(accuracy_gauge=accuracy_gauge, loss_gauge=loss_gauge) + start_fl_server(strategy=strategy_instance, rounds=args.number_of_rounds) diff --git a/examples/flower-via-docker-compose/strategy/strategy.py b/examples/flower-via-docker-compose/strategy/strategy.py new file mode 100644 index 000000000000..9471a99f037f --- /dev/null +++ b/examples/flower-via-docker-compose/strategy/strategy.py @@ -0,0 +1,60 @@ +from typing import Dict, List, Optional, Tuple, Union +from flwr.common import Scalar, EvaluateRes +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg +import flwr as fl +import logging +from prometheus_client import Gauge + +logging.basicConfig(level=logging.INFO) # Configure logging +logger = logging.getLogger(__name__) # Create logger for the module + + +class FedCustom(fl.server.strategy.FedAvg): + def __init__( + self, accuracy_gauge: Gauge = None, loss_gauge: Gauge = None, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.accuracy_gauge = accuracy_gauge + self.loss_gauge = loss_gauge + + def __repr__(self) -> str: + 
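+        """Return a human-readable name for the strategy."""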
return "FedCustom" + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses and accuracy using weighted average.""" + + if not results: + return None, {} + + # Calculate weighted average for loss using the provided function + loss_aggregated = weighted_loss_avg( + [ + (evaluate_res.num_examples, evaluate_res.loss) + for _, evaluate_res in results + ] + ) + + # Calculate weighted average for accuracy + accuracies = [ + evaluate_res.metrics["accuracy"] * evaluate_res.num_examples + for _, evaluate_res in results + ] + examples = [evaluate_res.num_examples for _, evaluate_res in results] + accuracy_aggregated = ( + sum(accuracies) / sum(examples) if sum(examples) != 0 else 0 + ) + + # Update the Prometheus gauges with the latest aggregated accuracy and loss values + self.accuracy_gauge.set(accuracy_aggregated) + self.loss_gauge.set(loss_aggregated) + + metrics_aggregated = {"loss": loss_aggregated, "accuracy": accuracy_aggregated} + + return loss_aggregated, metrics_aggregated diff --git a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift index 6075e689cb1e..3348ceae9e0e 100644 --- a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift +++ b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift @@ -4308,7 +4308,7 @@ struct CoreML_Specification_ConvolutionLayerParams { /// The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. /// /// Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. -/// Output spatial dimensions depend on the the type of padding. For details, refer to the +/// Output spatial dimensions depend on the type of padding. For details, refer to the /// descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. /// /// Example @@ -5219,7 +5219,7 @@ struct CoreML_Specification_Pooling3DLayerParams { /// SAME padding adds enough padding to each dimension such that the output /// has the same spatial dimensions as the input. Padding is added evenly to both /// sides of each dimension unless the total padding to add is odd, in which case the extra padding - /// is added to the back/bottom/right side of the respective dimension. For example, if the the + /// is added to the back/bottom/right side of the respective dimension. For example, if the /// total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. enum Pooling3DPaddingType: SwiftProtobuf.Enum { typealias RawValue = Int @@ -9493,7 +9493,7 @@ struct CoreML_Specification_ConstantPaddingLayerParams { ///* /// Length of this repeated field must be twice the rank of the first input. - /// 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the the i-th input + /// 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the i-th input /// dimension, "before" and "after" the input values, respectively. 
var padAmounts: [UInt64] = [] diff --git a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto index 6b2ebb1c8ba1..e80a3566c2c3 100644 --- a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto +++ b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto @@ -1554,7 +1554,7 @@ message ConvolutionLayerParams { * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. * * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. - * Output spatial dimensions depend on the the type of padding. For details, refer to the + * Output spatial dimensions depend on the type of padding. For details, refer to the * descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. * * Example @@ -2056,7 +2056,7 @@ message Pooling3DLayerParams { * SAME padding adds enough padding to each dimension such that the output * has the same spatial dimensions as the input. Padding is added evenly to both * sides of each dimension unless the total padding to add is odd, in which case the extra padding - * is added to the back/bottom/right side of the respective dimension. For example, if the the + * is added to the back/bottom/right side of the respective dimension. For example, if the * total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. */ enum Pooling3DPaddingType { @@ -4779,7 +4779,7 @@ enum ScatterMode { * * The output has the same shape as the first input. * "indices" input must have rank less than or equal to the "updates" input and its shape - * must be a subset of the the shape of the "updates" input. + * must be a subset of the shape of the "updates" input. * * e.g: * @@ -5064,7 +5064,7 @@ message ConstantPaddingLayerParams { /** * Length of this repeated field must be twice the rank of the first input. - * 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the the i-th input + * 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the i-th input * dimension, "before" and "after" the input values, respectively. */ repeated uint64 padAmounts = 2; diff --git a/examples/ios/pyproject.toml b/examples/ios/pyproject.toml index c1bdbb815bd5..2e55b14cf761 100644 --- a/examples/ios/pyproject.toml +++ b/examples/ios/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_ios_coreml" version = "0.1.0" description = "Example Server for Flower iOS/CoreML" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/llm-flowertune/README.md b/examples/llm-flowertune/README.md new file mode 100644 index 000000000000..60e183d2a9c0 --- /dev/null +++ b/examples/llm-flowertune/README.md @@ -0,0 +1,139 @@ +# Federated Large Language Model (LLM) Fine-tuning with Flower + +Large language models (LLMs), which have been trained on vast amounts of publicly accessible data, have shown remarkable effectiveness in a wide range of areas. +However, despite the fact that more data typically leads to improved performance, there is a concerning prospect that the supply of high-quality public data will deplete within a few years. +Federated LLM training could unlock access to an endless pool of distributed private data by allowing multiple data owners to collaboratively train a shared model without the need to exchange raw data. 
+
+This introductory example conducts federated instruction tuning with pretrained [OpenLLaMA](https://huggingface.co/openlm-research) models on the [Alpaca-GPT4](https://huggingface.co/datasets/vicgalle/alpaca-gpt4) dataset.
+We use [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset.
+The fine-tuning is done using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
+We use Flower's Simulation Engine to simulate the LLM fine-tuning process in a federated way,
+which allows users to perform the training on a single GPU.
+
+## Environment Setup
+
+Start by cloning the code example. We prepared a single-line command that you can copy into your shell; it will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/llm-flowertune . && rm -rf flower && cd llm-flowertune
+```
+
+This will create a new directory called `llm-flowertune` containing the following files:
+
+```
+-- README.md <- You're reading this right now
+-- main.py <- Start fed-LLM simulation
+-- client.py <- Flower client constructor
+-- models.py <- Model build
+-- dataset.py <- Dataset and tokenizer build
+-- utils.py <- Utility functions
+-- test.py <- Test pre-trained model
+-- app.py <- ServerApp/ClientApp for Flower-Next
+-- conf/config.yaml <- Configuration file
+-- requirements.txt <- Example dependencies
+```
+
+### Installing dependencies
+
+Project dependencies are defined in `requirements.txt`. Install them with:
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run LLM Fine-tuning
+
+With an activated Python environment, run the example with default config values. The config is in `conf/config.yaml` and is loaded automatically.
+
+```bash
+# Run with default config
+python main.py
+```
+
+This command will run FL simulations with a 4-bit [OpenLLaMA 7Bv2](https://huggingface.co/openlm-research/open_llama_7b_v2) model involving 2 clients per round for 100 FL rounds. You can override configuration parameters directly from the command line. Below are a few settings you might want to test:
+
+```bash
+# Use OpenLLaMA-3B instead of 7B and 8-bit quantization
+python main.py model.name="openlm-research/open_llama_3b_v2" model.quantization=8
+
+# Run for 50 rounds but increase the fraction of clients that participate per round to 25%
+python main.py num_rounds=50 strategy.fraction_fit=0.25
+```
+
+## Expected Results
+
+![](_static/train_loss_smooth.png)
+
+As expected, the OpenLLaMA-7B model works better than its 3B counterpart, reaching a lower training loss. With the hyperparameters tested, 8-bit quantization seems to deliver a lower training loss for the smaller 3B model than its 4-bit version.
+
+You can run all 8 experiments with a single command:
+
+```bash
+python main.py --multirun model.name="openlm-research/open_llama_7b_v2","openlm-research/open_llama_3b_v2" model.quantization=8,4 strategy.fraction_fit=0.1,0.2
+```
+
+## VRAM Consumption
+
+| Models | 7-billion (8-bit) | 7-billion (4-bit) | 3-billion (8-bit) | 3-billion (4-bit) |
+| :----: | :---------------: | :---------------: | :---------------: | :---------------: |
+| VRAM | ~22.00 GB | ~16.50 GB | ~13.50 GB | ~10.60 GB |
+
+We make use of the [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) library in conjunction with [PEFT](https://huggingface.co/docs/peft/en/index) to derive LLMs that can be fine-tuned efficiently.
+The above table shows the VRAM consumption per client for the different models considered in this example.
+You can adjust the CPU/GPU resources you assign to each of the clients based on your device.
+For example, it is easy to train 2 concurrent clients on each GPU (24 GB VRAM) if you choose the 3-billion (4-bit) model.
+
+```bash
+# This will assign 50% of the GPU's VRAM to each client.
+python main.py model.name="openlm-research/open_llama_3b_v2" model.quantization=4 client_resources.num_gpus=0.5
+```
+
+## Test with your Questions
+
+We provide a script to test your trained model by passing it questions of your choice. For example:
+
+```bash
+python test.py --peft-path=/path/to/trained-model-dir/ \
+ --question="What is the ideal 1-day plan in London?"
+```
+
+An answer generated by the federated fine-tuned 7-billion (8-bit) OpenLLaMA model:
+
+```
+Great choice.
+London has so much to offer, and you can really soak up all the sights and sounds in just a single day.
+Here's a suggested itinerary for you.
+Start your day off with a hearty breakfast at an authentic British diner.
+Then head to the iconic Big Ben and the Houses of Parliament to learn about the history of the city.
+Next, make your way to Westminster Abbey to see the many historical monuments and memorials.
+From there, cross the river Thames to the Tower of London, which is home to the Crown Jewels of England and Scotland.
+Finally, end your day with a relaxing visit to the London Eye, the tallest Ferris wheel in Europe, for a beautiful view of the city.
+```
+
+The [`Vicuna`](https://huggingface.co/lmsys/vicuna-13b-v1.1) template we used in this example is for a chat assistant.
+The generated answer is expected to be a multi-turn conversation. Feel free to try more interesting questions!
+
+## Run with Flower Next (preview)
+
+We use a 2-client setting to demonstrate how to run federated LLM fine-tuning with Flower Next.
+Please follow the steps below:
+
+1. Start the long-running Flower server (SuperLink)
+   ```bash
+   flower-superlink --insecure
+   ```
+2. Start the long-running Flower clients (SuperNodes)
+   ```bash
+   # In a new terminal window, start the first long-running Flower client:
+   flower-client-app app:client1 --insecure
+   ```
+   ```bash
+   # In another new terminal window, start the second long-running Flower client:
+   flower-client-app app:client2 --insecure
+   ```
+3. Run the Flower App
+   ```bash
+   # With both the long-running server (SuperLink) and two clients (SuperNodes) up and running,
+   # we can now run the actual Flower App:
+   flower-server-app app:server --insecure
+   ```
diff --git a/examples/llm-flowertune/_static/train_loss_smooth.png b/examples/llm-flowertune/_static/train_loss_smooth.png
new file mode 100644
index 000000000000..02034e6e9eb5
Binary files /dev/null and b/examples/llm-flowertune/_static/train_loss_smooth.png differ
diff --git a/examples/llm-flowertune/app.py b/examples/llm-flowertune/app.py
new file mode 100644
index 000000000000..e04ad8715de6
--- /dev/null
+++ b/examples/llm-flowertune/app.py
@@ -0,0 +1,86 @@
+import os
+import warnings
+from hydra import compose, initialize
+
+import flwr as fl
+from flwr_datasets import FederatedDataset
+
+from dataset import get_tokenizer_and_data_collator_and_propt_formatting
+from client import gen_client_fn
+from utils import get_on_fit_config, fit_weighted_average
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+NUM_ROUNDS = 100
+save_path = "./results/"
+
+with initialize(config_path="conf"):
+    cfg = compose(config_name="config")
+
+# Reset the number of rounds
+cfg.num_rounds = NUM_ROUNDS
+cfg.train.num_rounds = NUM_ROUNDS
+
+# Create output directory
+if not os.path.exists(save_path):
+    os.mkdir(save_path)
+
+# Partition dataset and get dataloaders
+# We set the number of partitions to 20 for fast processing.
+fds = FederatedDataset(
+    dataset=cfg.dataset.name, partitioners={"train": cfg.num_clients}
+)
+(
+    tokenizer,
+    data_collator,
+    formatting_prompts_func,
+) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+
+# ClientApp for client #1 (Flower Next)
+client1 = fl.client.ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+        partition_id=0,
+        api=True,
+    ),
+)
+
+
+# ClientApp for client #2 (Flower Next)
+client2 = fl.client.ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+        partition_id=1,
+        api=True,
+    ),
+)
+
+
+# Instantiate strategy.
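+# With min_available_clients=2 and fraction_fit=1.0, both ClientApps train in
+# every round; fraction_evaluate=0.0 disables client-side evaluation, so only
+# the weighted train_loss returned by fit() is aggregated.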
+strategy = fl.server.strategy.FedAvg(
+    min_available_clients=2,  # Simulate a 2-client setting
+    fraction_fit=1.0,
+    fraction_evaluate=0.0,  # no client evaluation
+    on_fit_config_fn=get_on_fit_config(),
+    fit_metrics_aggregation_fn=fit_weighted_average,
+)
+
+# ServerApp for Flower-Next
+server = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
+    strategy=strategy,
+)
diff --git a/examples/llm-flowertune/client.py b/examples/llm-flowertune/client.py
new file mode 100644
index 000000000000..28b324ba5bf1
--- /dev/null
+++ b/examples/llm-flowertune/client.py
@@ -0,0 +1,129 @@
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import flwr as fl
+import torch
+from flwr.common.typing import NDArrays, Scalar
+from omegaconf import DictConfig
+from trl import SFTTrainer
+from transformers import TrainingArguments
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+
+from models import get_model, cosine_annealing
+
+
+# pylint: disable=too-many-arguments
+class FlowerClient(
+    fl.client.NumPyClient
+):  # pylint: disable=too-many-instance-attributes
+    """Standard Flower client for federated LLM fine-tuning."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: DictConfig,
+        trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        save_path,
+    ):  # pylint: disable=too-many-arguments
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.train_cfg = train_cfg
+        self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
+        self.tokenizer = tokenizer
+        self.formatting_prompts_func = formatting_prompts_func
+        self.data_collator = data_collator
+        self.save_path = save_path
+
+        # Instantiate model
+        self.model = get_model(model_cfg)
+
+        self.trainset = trainset
+
+    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+        """Return the parameters of the current net."""
+
+        state_dict = get_peft_model_state_dict(self.model)
+        return [val.cpu().numpy() for _, val in state_dict.items()]
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Implement distributed fit function for a given client."""
+        set_parameters(self.model, parameters)
+
+        new_lr = cosine_annealing(
+            int(config["current_round"]),
+            self.train_cfg.num_rounds,
+            self.train_cfg.learning_rate_max,
+            self.train_cfg.learning_rate_min,
+        )
+
+        self.training_arguments.learning_rate = new_lr
+        self.training_arguments.output_dir = self.save_path
+
+        # Construct trainer
+        trainer = SFTTrainer(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            args=self.training_arguments,
+            max_seq_length=self.train_cfg.seq_length,
+            train_dataset=self.trainset,
+            formatting_func=self.formatting_prompts_func,
+            data_collator=self.data_collator,
+        )
+
+        # Do local training
+        results = trainer.train()
+
+        return (
+            self.get_parameters({}),
+            len(self.trainset),
+            {"train_loss": results.training_loss},
+        )
+
+
+def set_parameters(model, parameters: NDArrays) -> None:
+    """Change the parameters of the model using the given ones."""
+    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
+    params_dict = zip(peft_state_dict_keys, parameters)
+    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+    set_peft_model_state_dict(model, state_dict)
+
+
+def gen_client_fn(
+    fds,
+    tokenizer,
+    formatting_prompts_func,
+    data_collator,
+    model_cfg: DictConfig,
+    train_cfg: DictConfig,
+    save_path: str,
+    partition_id: int = 0,
+    api: bool = False,
+) ->
Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments + """Generate the client function that creates the Flower Clients.""" + + def client_fn(cid: str) -> FlowerClient: + """Create a Flower client representing a single organization.""" + + # Let's get the partition corresponding to the i-th client + client_trainset = ( + fds.load_partition(partition_id, "train") + if api + else fds.load_partition(int(cid), "train") + ) + client_trainset = client_trainset.rename_column("output", "response") + + return FlowerClient( + model_cfg, + train_cfg, + client_trainset, + tokenizer, + formatting_prompts_func, + data_collator, + save_path, + ).to_client() + + return client_fn diff --git a/examples/llm-flowertune/conf/config.yaml b/examples/llm-flowertune/conf/config.yaml new file mode 100644 index 000000000000..0b769d351479 --- /dev/null +++ b/examples/llm-flowertune/conf/config.yaml @@ -0,0 +1,45 @@ +# Federated Instruction Tuning on General Dataset +--- + +num_clients: 20 # total number of clients +num_rounds: 100 + +dataset: + name: "vicgalle/alpaca-gpt4" + +model: + name: "openlm-research/open_llama_7b_v2" + quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes + gradient_checkpointing: True + lora: + peft_lora_r: 32 + peft_lora_alpha: 64 + +train: + num_rounds: ${num_rounds} + save_every_round: 5 + learning_rate_max: 5e-5 + learning_rate_min: 1e-6 + seq_length: 512 + training_arguments: + output_dir: null # to be set by hydra + learning_rate: null # to be set by the client + per_device_train_batch_size: 16 + gradient_accumulation_steps: 1 + logging_steps: 10 + num_train_epochs: 3 + max_steps: 10 + report_to: null + save_steps: 1000 + save_total_limit: 10 + gradient_checkpointing: ${model.gradient_checkpointing} + lr_scheduler_type: "constant" + +strategy: + _target_: flwr.server.strategy.FedAvg + fraction_fit: 0.1 # sample 10% of clients (i.e. 2 per round) + fraction_evaluate: 0.0 # no client evaluation + +client_resources: + num_cpus: 8 + num_gpus: 1.0 diff --git a/examples/llm-flowertune/dataset.py b/examples/llm-flowertune/dataset.py new file mode 100644 index 000000000000..571be31f7fba --- /dev/null +++ b/examples/llm-flowertune/dataset.py @@ -0,0 +1,29 @@ +from transformers import AutoTokenizer +from trl import DataCollatorForCompletionOnlyLM + + +def formatting_prompts_func(example): + output_texts = [] + # Constructing a standard Alpaca (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt + mssg = "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
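+    # `example` holds batched columns, so build one Alpaca-style prompt per
+    # (instruction, response) pair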
+ for i in range(len(example["instruction"])): + text = f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n### Response: {example['response'][i]}" + output_texts.append(text) + return output_texts + + +def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str): + # From: https://huggingface.co/docs/trl/en/sft_trainer + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=True, padding_side="right" + ) + tokenizer.pad_token = tokenizer.eos_token + response_template_with_context = "\n### Response:" # alpaca response tag + response_template_ids = tokenizer.encode( + response_template_with_context, add_special_tokens=False + )[2:] + data_collator = DataCollatorForCompletionOnlyLM( + response_template_ids, tokenizer=tokenizer + ) + + return tokenizer, data_collator, formatting_prompts_func diff --git a/examples/llm-flowertune/main.py b/examples/llm-flowertune/main.py new file mode 100644 index 000000000000..2d03e9cbcae5 --- /dev/null +++ b/examples/llm-flowertune/main.py @@ -0,0 +1,94 @@ +import warnings +import pickle + +import flwr as fl +from flwr_datasets import FederatedDataset + +import hydra +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from dataset import get_tokenizer_and_data_collator_and_propt_formatting +from utils import get_on_fit_config, fit_weighted_average, get_evaluate_fn +from client import gen_client_fn + + +warnings.filterwarnings("ignore", category=UserWarning) + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg: DictConfig) -> None: + """Run federated LLM fine-tuning. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # Print config structured as YAML + print(OmegaConf.to_yaml(cfg)) + + # Partition dataset and get dataloaders + fds = FederatedDataset( + dataset=cfg.dataset.name, partitioners={"train": cfg.num_clients} + ) + ( + tokenizer, + data_collator, + formatting_prompts_func, + ) = get_tokenizer_and_data_collator_and_propt_formatting( + cfg.model.name, + ) + + # Hydra automatically creates an output directory + # Let's retrieve it and save some results there + save_path = HydraConfig.get().runtime.output_dir + + # Prepare function that will be used to spawn each client + client_fn = gen_client_fn( + fds, + tokenizer, + formatting_prompts_func, + data_collator, + cfg.model, + cfg.train, + save_path, + ) + + # Instantiate strategy according to config. Here we pass other arguments + # that are only defined at run time. + strategy = instantiate( + cfg.strategy, + on_fit_config_fn=get_on_fit_config(), + fit_metrics_aggregation_fn=fit_weighted_average, + evaluate_fn=get_evaluate_fn( + cfg.model, cfg.train.save_every_round, cfg.num_rounds, save_path + ), + ) + + # Start simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources={ + "num_cpus": cfg.client_resources.num_cpus, + "num_gpus": cfg.client_resources.num_gpus, + }, + strategy=strategy, + ) + + # Experiment completed. 
Now we save the results
+    # stored in the `history` object
+    print("................")
+    print(history)
+
+    # Save results as a Python pickle in
+    # the directory created by Hydra for each run
+    with open(f"{save_path}/results.pkl", "wb") as f:
+        pickle.dump(history, f)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/llm-flowertune/models.py b/examples/llm-flowertune/models.py
new file mode 100644
index 000000000000..78eef75d10d2
--- /dev/null
+++ b/examples/llm-flowertune/models.py
@@ -0,0 +1,56 @@
+import torch
+from omegaconf import DictConfig
+from transformers import AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from peft import get_peft_model, LoraConfig
+from peft.utils import prepare_model_for_kbit_training
+
+import math
+
+
+def cosine_annealing(
+    current_round: int,
+    total_round: int,
+    lrate_max: float = 0.001,
+    lrate_min: float = 0.0,
+) -> float:
+    """Implement cosine annealing learning rate schedule."""
+
+    cos_inner = math.pi * current_round / total_round
+    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
+
+
+def get_model(model_cfg: DictConfig):
+    """Load model with appropriate quantization config and other optimizations.
+
+    Please refer to this example for `peft + BitsAndBytes`:
+    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
+    """
+
+    if model_cfg.quantization == 4:
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+    elif model_cfg.quantization == 8:
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+    else:
+        raise ValueError(
+            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}"
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_cfg.name,
+        quantization_config=quantization_config,
+        torch_dtype=torch.bfloat16,
+    )
+
+    model = prepare_model_for_kbit_training(
+        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
+    )
+
+    peft_config = LoraConfig(
+        r=model_cfg.lora.peft_lora_r,
+        lora_alpha=model_cfg.lora.peft_lora_alpha,
+        lora_dropout=0.075,
+        task_type="CAUSAL_LM",
+    )
+
+    return get_peft_model(model, peft_config)
diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt
new file mode 100644
index 000000000000..e557dbfc2ff8
--- /dev/null
+++ b/examples/llm-flowertune/requirements.txt
@@ -0,0 +1,8 @@
+flwr-nightly[rest,simulation]
+flwr_datasets==0.0.2
+hydra-core==1.3.2
+trl==0.7.2
+bitsandbytes==0.40.2
+scipy==1.11.2
+peft==0.4.0
+fschat[model_worker,webui]==0.2.35
diff --git a/examples/llm-flowertune/test.py b/examples/llm-flowertune/test.py
new file mode 100644
index 000000000000..652bb9aafcf5
--- /dev/null
+++ b/examples/llm-flowertune/test.py
@@ -0,0 +1,73 @@
+# This Python file is adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
+
+import argparse
+import torch
+from peft import AutoPeftModelForCausalLM
+from transformers import AutoTokenizer
+from fastchat.conversation import get_conv_template
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--peft-path", type=str, default=None)
+parser.add_argument("--question", type=str, default="How are you")
+parser.add_argument("--template", type=str, default="vicuna_v1.1")
+args = parser.parse_args()
+
+# Load model and tokenizer
+model = AutoPeftModelForCausalLM.from_pretrained(
+    args.peft_path, torch_dtype=torch.float16
+).to("cuda")
+base_model = model.peft_config["default"].base_model_name_or_path
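+# The PEFT checkpoint stores only adapter weights, so the tokenizer comes from
+# the base model recorded in the adapter config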
+tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+# Generate answers
+temperature = 0.7
+choices = []
+conv = get_conv_template(args.template)
+
+conv.append_message(conv.roles[0], args.question)
+conv.append_message(conv.roles[1], None)
+prompt = conv.get_prompt()
+input_ids = tokenizer([prompt]).input_ids
+
+output_ids = model.generate(
+    input_ids=torch.as_tensor(input_ids).cuda(),
+    do_sample=True,
+    temperature=temperature,
+    max_new_tokens=1024,
+)
+
+output_ids = (
+    output_ids[0]
+    if model.config.is_encoder_decoder
+    else output_ids[0][len(input_ids[0]) :]
+)
+
+# Be consistent with the template's stop_token_ids
+if conv.stop_token_ids:
+    stop_token_ids_index = [
+        i for i, id in enumerate(output_ids) if id in conv.stop_token_ids
+    ]
+    if len(stop_token_ids_index) > 0:
+        output_ids = output_ids[: stop_token_ids_index[0]]
+
+output = tokenizer.decode(
+    output_ids,
+    spaces_between_special_tokens=False,
+)
+
+if conv.stop_str and output.find(conv.stop_str) > 0:
+    output = output[: output.find(conv.stop_str)]
+
+for special_token in tokenizer.special_tokens_map.values():
+    if isinstance(special_token, list):
+        for special_tok in special_token:
+            output = output.replace(special_tok, "")
+    else:
+        output = output.replace(special_token, "")
+output = output.strip()
+
+conv.update_last_message(output)
+
+print(f">> prompt: {prompt}")
+print(f">> Generated: {output}")
diff --git a/examples/llm-flowertune/utils.py b/examples/llm-flowertune/utils.py
new file mode 100644
index 000000000000..bbb607810537
--- /dev/null
+++ b/examples/llm-flowertune/utils.py
@@ -0,0 +1,43 @@
+from client import set_parameters
+from models import get_model
+
+
+# Get function that will be executed by the strategy's evaluate() method
+# Here we use it to save global model checkpoints
+def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
+    """Return an evaluation function for saving global model."""
+
+    def evaluate(server_round: int, parameters, config):
+        # Save model
+        if server_round != 0 and (
+            server_round == total_round or server_round % save_every_round == 0
+        ):
+            # Init model
+            model = get_model(model_cfg)
+            set_parameters(model, parameters)
+
+            model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+        return 0.0, {}
+
+    return evaluate
+
+
+# Get a function that will be used to construct the config that the client's
+# fit() method will receive
+def get_on_fit_config():
+    def fit_config_fn(server_round: int):
+        fit_config = {"current_round": server_round}
+        return fit_config
+
+    return fit_config_fn
+
+
+def fit_weighted_average(metrics):
+    """Aggregation function for (federated) fit metrics, i.e., train loss."""
+    # Multiply train loss of each client by number of examples used
+    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"train_loss": sum(losses) / sum(examples)}
diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md
deleted file mode 100644
index 120e28098344..000000000000
--- a/examples/mt-pytorch-callable/README.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Deploy 🧪
-
-🧪 = this page covers experimental features that might change in future versions of Flower
-
-This how-to guide describes the deployment of a long-running Flower server.
-
-## Preconditions
-
-Let's assume the following project structure:
-
-```bash
-$ tree .
-.
-└── client.py -├── driver.py -├── requirements.txt -``` - -## Install dependencies - -```bash -pip install -r requirements.txt -``` - -## Start the long-running Flower server - -```bash -flower-server --insecure -``` - -## Start the long-running Flower client - -In a new terminal window, start the first long-running Flower client: - -```bash -flower-client --insecure client:flower -``` - -In yet another new terminal window, start the second long-running Flower client: - -```bash -flower-client --insecure client:flower -``` - -## Start the Driver script - -```bash -python driver.py -``` diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py deleted file mode 100644 index 4195a714ca89..000000000000 --- a/examples/mt-pytorch-callable/client.py +++ /dev/null @@ -1,123 +0,0 @@ -import warnings -from collections import OrderedDict - -import flwr as fl -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR10 -from torchvision.transforms import Compose, Normalize, ToTensor -from tqdm import tqdm - - -# ############################################################################# -# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader -# ############################################################################# - -warnings.filterwarnings("ignore", category=UserWarning) -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Net(nn.Module): - """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" - - def __init__(self) -> None: - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return self.fc3(x) - - -def train(net, trainloader, epochs): - """Train the model on the training set.""" - criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - for _ in range(epochs): - for images, labels in tqdm(trainloader): - optimizer.zero_grad() - criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() - optimizer.step() - - -def test(net, testloader): - """Validate the model on the test set.""" - criterion = torch.nn.CrossEntropyLoss() - correct, loss = 0, 0.0 - with torch.no_grad(): - for images, labels in tqdm(testloader): - outputs = net(images.to(DEVICE)) - labels = labels.to(DEVICE) - loss += criterion(outputs, labels).item() - correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() - accuracy = correct / len(testloader.dataset) - return loss, accuracy - - -def load_data(): - """Load CIFAR-10 (training and test set).""" - trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - trainset = CIFAR10("./data", train=True, download=True, transform=trf) - testset = CIFAR10("./data", train=False, download=True, transform=trf) - return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) - - -# ############################################################################# -# 2. 
Federation of the pipeline with Flower -# ############################################################################# - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data() - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - return [val.cpu().numpy() for _, val in net.state_dict().items()] - - def set_parameters(self, parameters): - params_dict = zip(net.state_dict().keys(), parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - net.load_state_dict(state_dict, strict=True) - - def fit(self, parameters, config): - self.set_parameters(parameters) - train(net, trainloader, epochs=1) - return self.get_parameters(config={}), len(trainloader.dataset), {} - - def evaluate(self, parameters, config): - self.set_parameters(parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} - - -def client_fn(cid: str): - """.""" - return FlowerClient().to_client() - - -# To run this: `flower-client client:flower` -flower = fl.flower.Flower( - client_fn=client_fn, -) - - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient().to_client(), - transport="grpc-rere", - ) diff --git a/examples/mt-pytorch-callable/driver.py b/examples/mt-pytorch-callable/driver.py deleted file mode 100644 index 1248672b6813..000000000000 --- a/examples/mt-pytorch-callable/driver.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, Tuple - -import flwr as fl -from flwr.common import Metrics - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - # Multiply accuracy of each client by number of examples used - accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] - examples = [num_examples for num_examples, _ in metrics] - - # Aggregate and return custom metric (weighted average) - return {"accuracy": sum(accuracies) / sum(examples)} - - -# Define strategy -strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) - -# Start Flower driver -fl.driver.start_driver( - server_address="0.0.0.0:9091", - config=fl.server.ServerConfig(num_rounds=3), - strategy=strategy, -) diff --git a/examples/mt-pytorch-callable/requirements.txt b/examples/mt-pytorch-callable/requirements.txt deleted file mode 100644 index 797ca6db6244..000000000000 --- a/examples/mt-pytorch-callable/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr>=1.0, <2.0 -torch==1.13.1 -torchvision==0.14.1 -tqdm==4.65.0 diff --git a/examples/mt-pytorch/README.md b/examples/mt-pytorch/README.md deleted file mode 100644 index ef9516314e26..000000000000 --- a/examples/mt-pytorch/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Multi-Tenant Federated Learning with Flower and PyTorch - -This example contains experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. 
- -## Setup - -```bash -./dev/venv-reset.sh -``` - -## Run with Driver API - -Terminal 1: start Flower server - -```bash -flower-server -``` - -Terminal 2+3: start two Flower client nodes - -```bash -python client.py -``` - -Terminal 4: start Driver script - -Using: - -```bash -python start_driver.py -``` - -Or, alternatively: - -```bash -python driver.py -``` - -## Run in legacy mode - -Terminal 1: start Flower server - -```bash -python server.py -``` - -Terminal 2+3: start two clients - -```bash -python client.py -``` - -## Run with Driver API (REST transport layer) - -Terminal 1: start Flower server and enable REST transport layer - -```bash -flower-server --rest -``` - -Terminal 2: start Driver script - -```bash -python driver.py -``` - -Open file `client.py` adjust `server_address` and `transport`. - -Terminal 3+4: start two Flower client nodes - -```bash -python client.py -``` diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py deleted file mode 100644 index 23cc736fd62b..000000000000 --- a/examples/mt-pytorch/client.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Dict -import flwr as fl -from flwr.common import NDArrays, Scalar - -from task import ( - Net, - DEVICE, - load_data, - get_parameters, - set_parameters, - train, - test, -) - - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data() - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - return get_parameters(net) - - def fit(self, parameters, config): - set_parameters(net, parameters) - results = train(net, trainloader, testloader, epochs=1, device=DEVICE) - return get_parameters(net), len(trainloader.dataset), results - - def evaluate(self, parameters, config): - set_parameters(net, parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} - - -# Start Flower client -fl.client.start_numpy_client( - server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST - client=FlowerClient(), - transport="grpc-rere", # "rest" for REST -) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py deleted file mode 100644 index ad4d5e1caabe..000000000000 --- a/examples/mt-pytorch/driver.py +++ /dev/null @@ -1,226 +0,0 @@ -from typing import List, Tuple -import random -import time - -from flwr.driver import GrpcDriver -from flwr.common import ( - ServerMessage, - FitIns, - ndarrays_to_parameters, - serde, - parameters_to_ndarrays, - ClientMessage, - NDArrays, - Code, -) -from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 -from flwr.server.strategy.aggregate import aggregate -from flwr.common import Metrics -from flwr.server import History -from flwr.common import serde -from task import Net, get_parameters, set_parameters - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": 
sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) -# -------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 2 -sleep_time = 1 -num_rounds = 3 -parameters = ndarrays_to_parameters(get_parameters(net=Net())) - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_run_res: driver_pb2.CreateRunResponse = driver.create_run( - req=driver_pb2.CreateRunRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -run_id = create_run_res.run_id -print(f"Created run id {run_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_nodes: List[node_pb2.Node] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. - # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids: List[int] = [0] * num_client_nodes_per_round - sampled_nodes = [ - node_pb2.Node(node_id=node_id, anonymous=False) - for node_id in sampled_node_ids - ] - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. 
- while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_nodes: List[node_pb2.Node] = get_nodes_res.nodes - print(f"Got {len(all_nodes)} client nodes") - - if len(all_nodes) >= num_client_nodes_per_round: - # Sample client nodes - sampled_nodes = random.sample(all_nodes, num_client_nodes_per_round) - break - - time.sleep(3) - - # Log sampled node IDs - print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") - time.sleep(sleep_time) - - # Schedule a task for all sampled nodes - fit_ins: FitIns = FitIns(parameters=parameters, config={}) - server_message_proto: transport_pb2.ServerMessage = serde.server_message_to_proto( - server_message=ServerMessage(fit_ins=fit_ins) - ) - task_ins_list: List[task_pb2.TaskIns] = [] - for sampled_node in sampled_nodes: - new_task_ins = task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - run_id=run_id, - task=task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=sampled_node, - legacy_server_message=server_message_proto, - ), - ) - task_ins_list.append(new_task_ins) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - print( - f"Scheduled {len(push_task_ins_res.task_ids)} tasks: {push_task_ins_res.task_ids}" - ) - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - print(f"Got {len(task_res_list)} results") - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages: List[ClientMessage] = [] - for task_res in all_task_res: - if task_res.task.HasField("legacy_client_message"): - node_messages.append(task_res.task.legacy_client_message) - print(f"Received {len(node_messages)} results") - - weights_results: List[Tuple[NDArrays, int]] = [] - metrics_results: List = [] - for node_message in node_messages: - if not node_message.fit_res: - continue - fit_res = node_message.fit_res - # Aggregate only if the status is OK - if fit_res.status.code != Code.OK.value: - continue - weights_results.append( - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - ) - metrics_results.append( - (fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics)) - ) - - # Aggregate parameters (FedAvg) - 
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated - - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) - - # Slow down the start of the next round - time.sleep(sleep_time) - -print("app_fit: losses_distributed %s", str(history.losses_distributed)) -print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit)) -print("app_fit: metrics_distributed %s", str(history.metrics_distributed)) -print("app_fit: losses_centralized %s", str(history.losses_centralized)) -print("app_fit: metrics_centralized %s", str(history.metrics_centralized)) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/mxnet-from-centralized-to-federated/client.py b/examples/mxnet-from-centralized-to-federated/client.py index 3a3a9de146ef..bb666a26508e 100644 --- a/examples/mxnet-from-centralized-to-federated/client.py +++ b/examples/mxnet-from-centralized-to-federated/client.py @@ -1,6 +1,5 @@ """Flower client example using MXNet for MNIST classification.""" - from typing import Dict, List, Tuple import flwr as fl diff --git a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py index a53ed018eb48..5cf39da7c9ca 100644 --- a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py +++ b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - from typing import List, Tuple import mxnet as mx from mxnet import gluon diff --git a/examples/mxnet-from-centralized-to-federated/pyproject.toml b/examples/mxnet-from-centralized-to-federated/pyproject.toml index 952683eb90f6..b00b3ddfe412 100644 --- a/examples/mxnet-from-centralized-to-federated/pyproject.toml +++ b/examples/mxnet-from-centralized-to-federated/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/mxnet-from-centralized-to-federated/server.py b/examples/mxnet-from-centralized-to-federated/server.py index 29cbce1884d1..871aa4e8ec99 100644 --- a/examples/mxnet-from-centralized-to-federated/server.py +++ b/examples/mxnet-from-centralized-to-federated/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - import flwr as fl if __name__ == "__main__": diff --git a/examples/opacus/dp_cifar_client.py b/examples/opacus/dp_cifar_client.py index bab1451ba707..cc30e7728222 100644 --- a/examples/opacus/dp_cifar_client.py +++ b/examples/opacus/dp_cifar_client.py @@ -28,7 +28,7 @@ def load_data(): model = Net() trainloader, testloader, sample_rate = load_data() -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=DPCifarClient(model, trainloader, testloader), + client=DPCifarClient(model, trainloader, testloader).to_client(), ) diff --git a/examples/opacus/dp_cifar_simulation.py b/examples/opacus/dp_cifar_simulation.py index 
14a9d037685b..d957caf8785c 100644 --- a/examples/opacus/dp_cifar_simulation.py +++ b/examples/opacus/dp_cifar_simulation.py @@ -1,14 +1,14 @@ import math from collections import OrderedDict -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import flwr as fl import numpy as np import torch import torchvision.transforms as transforms -from opacus.dp_model_inspector import DPModelInspector from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 +from flwr.common.typing import Scalar from dp_cifar_main import DEVICE, PARAMS, DPCifarClient, Net, test @@ -23,8 +23,6 @@ def client_fn(cid: str) -> fl.client.Client: # Load model. model = Net() # Check model is compatible with Opacus. - # inspector = DPModelInspector() - # print(f"Is the model valid? {inspector.validate(model)}") # Load data partition (divide CIFAR10 into NUM_CLIENTS distinct partitions, using 30% for validation). transform = transforms.Compose( @@ -45,12 +43,14 @@ def client_fn(cid: str) -> fl.client.Client: client_trainloader = DataLoader(client_trainset, PARAMS["batch_size"]) client_testloader = DataLoader(client_testset, PARAMS["batch_size"]) - return DPCifarClient(model, client_trainloader, client_testloader) + return DPCifarClient(model, client_trainloader, client_testloader).to_client() # Define an evaluation function for centralized evaluation (using whole CIFAR10 testset). def get_evaluate_fn() -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: - def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: + def evaluate( + server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] + ): transform = transforms.Compose( [ transforms.ToTensor(), @@ -63,7 +63,7 @@ def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: state_dict = OrderedDict( { k: torch.tensor(np.atleast_1d(v)) - for k, v in zip(model.state_dict().keys(), weights) + for k, v in zip(model.state_dict().keys(), parameters) } ) model.load_state_dict(state_dict, strict=True) @@ -82,7 +82,7 @@ def main() -> None: client_fn=client_fn, num_clients=NUM_CLIENTS, client_resources={"num_cpus": 1}, - num_rounds=3, + config=fl.server.ServerConfig(num_rounds=3), strategy=fl.server.strategy.FedAvg( fraction_fit=0.1, fraction_evaluate=0.1, evaluate_fn=get_evaluate_fn() ), diff --git a/examples/opacus/pyproject.toml b/examples/opacus/pyproject.toml index af0eaf596fbf..26914fa27aa4 100644 --- a/examples/opacus/pyproject.toml +++ b/examples/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_opacus" version = "0.1.0" description = "Differentially Private Federated Learning with Opacus and Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-federated-variational-autoencoder/client.py b/examples/pytorch-federated-variational-autoencoder/client.py index ceb55c79f564..fc71f7e70c0b 100644 --- a/examples/pytorch-federated-variational-autoencoder/client.py +++ b/examples/pytorch-federated-variational-autoencoder/client.py @@ -93,7 +93,9 @@ def evaluate(self, parameters, config): loss = test(net, testloader) return float(loss), len(testloader), {} - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() + ) if __name__ == "__main__": diff --git 
a/examples/pytorch-federated-variational-autoencoder/pyproject.toml b/examples/pytorch-federated-variational-autoencoder/pyproject.toml index 116140306f62..bc1f85803682 100644 --- a/examples/pytorch-federated-variational-autoencoder/pyproject.toml +++ b/examples/pytorch-federated-variational-autoencoder/pyproject.toml @@ -2,7 +2,7 @@ name = "pytorch_federated_variational_autoencoder" version = "0.1.0" description = "Federated Variational Autoencoder Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index fccb14158ecd..06ee89dddcac 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -2,7 +2,7 @@ This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. -This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. 
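As a quick illustration of the Flower Datasets flow this README describes (download, partition, preprocess), here is a minimal sketch. It assumes `flwr-datasets` is installed with the vision extra; the partition count and the `img` column name mirror the `cifar.py` changes below.

```python
# Sketch: download CIFAR-10, load one of 10 partitions, attach preprocessing.
# Assumes `pip install "flwr-datasets[vision]" torchvision`.
from flwr_datasets import FederatedDataset
from torchvision.transforms import Compose, Normalize, ToTensor

fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(0)  # data for partition id 0

pytorch_transforms = Compose(
    [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
    # Convert each PIL image in the batch to a normalized tensor
    batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
    return batch

partition = partition.with_transform(apply_transforms)
```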
## Project Setup diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index e8f3ec3fd724..277a21da2e70 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -51,10 +51,10 @@ def forward(self, x: Tensor) -> Tensor: return x -def load_data(node_id: int): +def load_data(partition_id: int): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index 61c7e7f762b3..9df4739e0aab 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -1,4 +1,5 @@ """Flower client example using PyTorch for CIFAR-10 image classification.""" + import argparse from collections import OrderedDict from typing import Dict, List, Tuple @@ -80,11 +81,11 @@ def evaluate( def main() -> None: """Load data, start CifarClient.""" parser = argparse.ArgumentParser(description="Flower") - parser.add_argument("--node-id", type=int, required=True, choices=range(0, 10)) + parser.add_argument("--partition-id", type=int, required=True, choices=range(0, 10)) args = parser.parse_args() # Load data - trainloader, testloader = cifar.load_data(args.node_id) + trainloader, testloader = cifar.load_data(args.partition_id) # Load model model = cifar.Net().to(DEVICE).train() @@ -93,8 +94,8 @@ def main() -> None: _ = model(next(iter(trainloader))["img"].to(DEVICE)) # Start client - client = CifarClient(model, trainloader, testloader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(model, trainloader, testloader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/pytorch-from-centralized-to-federated/pyproject.toml b/examples/pytorch-from-centralized-to-federated/pyproject.toml index 6d6f138a0aea..3d1559e3a515 100644 --- a/examples/pytorch-from-centralized-to-federated/pyproject.toml +++ b/examples/pytorch-from-centralized-to-federated/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch: From Centralized To Federated with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-from-centralized-to-federated/run.sh b/examples/pytorch-from-centralized-to-federated/run.sh index 1ed51dd787ac..6ddf6ad476b4 100755 --- a/examples/pytorch-from-centralized-to-federated/run.sh +++ b/examples/pytorch-from-centralized-to-federated/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/pytorch-from-centralized-to-federated/server.py b/examples/pytorch-from-centralized-to-federated/server.py index 42f34b3a78e9..5190d690dc20 100644 --- 
a/examples/pytorch-from-centralized-to-federated/server.py +++ b/examples/pytorch-from-centralized-to-federated/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - from typing import List, Tuple import flwr as fl diff --git a/examples/quickstart-cpp/driver.py b/examples/quickstart-cpp/driver.py index 037623ee77cf..f19cf0e9bd98 100644 --- a/examples/quickstart-cpp/driver.py +++ b/examples/quickstart-cpp/driver.py @@ -3,7 +3,7 @@ # Start Flower server for three rounds of federated learning if __name__ == "__main__": - fl.driver.start_driver( + fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=FedAvgCpp(), diff --git a/examples/quickstart-fastai/client.py b/examples/quickstart-fastai/client.py index a88abbe525dc..6bb2a751d544 100644 --- a/examples/quickstart-fastai/client.py +++ b/examples/quickstart-fastai/client.py @@ -43,7 +43,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-fastai/pyproject.toml b/examples/quickstart-fastai/pyproject.toml index ffaa97267493..19a25291a6af 100644 --- a/examples/quickstart-fastai/pyproject.toml +++ b/examples/quickstart-fastai/pyproject.toml @@ -6,9 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-fastai" version = "0.1.0" description = "Fastai Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.10" -flwr = "^1.0.0" -fastai = "^2.7.10" +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +fastai = "2.7.14" +torch = "2.2.0" +torchvision = "0.17.0" diff --git a/examples/quickstart-fastai/requirements.txt b/examples/quickstart-fastai/requirements.txt index 0a7f315018a2..9c6e8d77293a 100644 --- a/examples/quickstart-fastai/requirements.txt +++ b/examples/quickstart-fastai/requirements.txt @@ -1,3 +1,4 @@ -fastai~=2.7.12 -flwr~=1.4.0 -torch~=2.0.1 +flwr>=1.0, <2.0 +fastai==2.7.14 +torch==2.2.0 +torchvision==0.17.0 diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index fd868aa1fcce..ce7790cd4af5 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,6 +1,6 @@ # Federated HuggingFace Transformers using Flower and PyTorch -This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. +This introductory example shows how to use [HuggingFace](https://huggingface.co) Transformers with Flower and PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.ai/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy.
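For orientation, the transformer pipeline this example federates boils down to the tokenization step sketched below. The checkpoint matches `client.py`; the dynamic-padding collator is an assumption based on the HuggingFace course the README points to, not a verbatim excerpt.

```python
# Sketch: tokenize IMDB-style text for DistilBERT sequence classification.
from transformers import AutoTokenizer, DataCollatorWithPadding

CHECKPOINT = "distilbert-base-uncased"  # same checkpoint as client.py
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

def tokenize_function(examples):
    # Truncate each review to the model's maximum input length
    return tokenizer(examples["text"], truncation=True)

# Pad each batch dynamically to its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
```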
@@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 85792df43e53..9be08d0cbcf4 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -17,10 +17,10 @@ CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(node_id): +def load_data(partition_id): """Load IMDB data (training and eval)""" fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) @@ -78,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(node_id): +def main(partition_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data(node_id) + trainloader, testloader = load_data(partition_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -108,18 +108,20 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} # Start client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=IMDBClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=IMDBClient().to_client() + ) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=list(range(1_000)), required=True, type=int, help="Partition of the dataset divided into 1,000 iid partitions created " "artificially.", ) - node_id = parser.parse_args().node_id - main(node_id) + partition_id = parser.parse_args().partition_id + main(partition_id) diff --git a/examples/quickstart-huggingface/pyproject.toml b/examples/quickstart-huggingface/pyproject.toml index 50ba0b37f8d2..2b46804d7b45 100644 --- a/examples/quickstart-huggingface/pyproject.toml +++ b/examples/quickstart-huggingface/pyproject.toml @@ -7,8 +7,8 @@ name = "quickstart-huggingface" version = "0.1.0" description = "Hugging Face Transformers Federated Learning Quickstart with Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index e722a24a21a9..fa989eab1471 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i}& + python client.py --partition-id ${i}& done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-jax/README.md b/examples/quickstart-jax/README.md index 9208337b86de..836adf558d88 100644 --- a/examples/quickstart-jax/README.md +++ b/examples/quickstart-jax/README.md @@ -52,7 +52,7 @@ pip install -r requirements.txt ## Run JAX Federated -This JAX example is based on the [Linear Regression with 
JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression pronlem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings being done while calling the dataset function. Please checkout out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set. +This JAX example is based on the [Linear Regression with JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression problem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings used when calling the dataset function. Please check out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set. The only things we need are a simple Flower server (in `server.py`) and a Flower client (in `client.py`). The Flower client basically takes a model and training code and tells Flower how to call it. diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py index f9b056276deb..2257a3d6daa3 100644 --- a/examples/quickstart-jax/client.py +++ b/examples/quickstart-jax/client.py @@ -1,6 +1,5 @@ """Flower client example using JAX for linear regression.""" - from typing import Dict, List, Tuple, Callable import flwr as fl @@ -52,4 +51,6 @@ def evaluate( # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/quickstart-jax/jax_training.py b/examples/quickstart-jax/jax_training.py index 2b523a08516e..a2e23a0927bc 100644 --- a/examples/quickstart-jax/jax_training.py +++ b/examples/quickstart-jax/jax_training.py @@ -7,7 +7,6 @@ please read the JAX documentation or the mentioned tutorial.
""" - from typing import Dict, List, Tuple, Callable import jax import jax.numpy as jnp diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index 41b4462d0a14..c956191369b5 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -2,7 +2,7 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-mlcube/client.py b/examples/quickstart-mlcube/client.py index 0a1962d8da8a..46ddd45f52ce 100644 --- a/examples/quickstart-mlcube/client.py +++ b/examples/quickstart-mlcube/client.py @@ -43,8 +43,9 @@ def main(): os.path.dirname(os.path.abspath(__file__)), "workspaces", workspace_name ) - fl.client.start_numpy_client( - server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace) + fl.client.start_client( + server_address="0.0.0.0:8080", + client=MLCubeClient(workspace=workspace).to_client(), ) diff --git a/examples/quickstart-mlcube/pyproject.toml b/examples/quickstart-mlcube/pyproject.toml index 0d42fc3b2898..a2862bd5ebb7 100644 --- a/examples/quickstart-mlcube/pyproject.toml +++ b/examples/quickstart-mlcube/pyproject.toml @@ -6,13 +6,13 @@ build-backend = "poetry.masonry.api" name = "quickstart-ml-cube" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" # For development: { path = "../../", develop = true } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } mlcube = "0.0.9" mlcube-docker = "0.0.9" tensorflow-estimator = ">=2.9.1,<2.11.1 || >2.11.1" diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index d94a87a014f7..cca55bcb946a 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -66,19 +66,19 @@ following commands. Start a first client in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` And another one in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` If you want to utilize your GPU, you can use the `--gpu` argument: ```shell -python3 client.py --gpu --node-id 2 +python3 client.py --gpu --partition-id 2 ``` Note that you can start many more clients if you want, but each will have to be in its own terminal. 
@@ -96,7 +96,7 @@ We will use `flwr_datasets` to easily download and partition the `MNIST` dataset ```python fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) -partition = fds.load_partition(node_id = args.node_id) +partition = fds.load_partition(partition_id = args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits['train'].set_format("numpy") diff --git a/examples/quickstart-mlx/client.py b/examples/quickstart-mlx/client.py index 3b506399a5f1..faba2b94d6bd 100644 --- a/examples/quickstart-mlx/client.py +++ b/examples/quickstart-mlx/client.py @@ -89,7 +89,7 @@ def evaluate(self, parameters, config): parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", @@ -106,7 +106,7 @@ def evaluate(self, parameters, config): learning_rate = 1e-1 fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) - partition = fds.load_partition(node_id=args.node_id) + partition = fds.load_partition(partition_id=args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits["train"].set_format("numpy") diff --git a/examples/quickstart-mlx/pyproject.toml b/examples/quickstart-mlx/pyproject.toml index deb541c5ba9c..752040b6aaa9 100644 --- a/examples/quickstart-mlx/pyproject.toml +++ b/examples/quickstart-mlx/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-mlx" version = "0.1.0" description = "MLX Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" mlx = "==0.0.3" numpy = "==1.24.4" -flwr-datasets = {extras = ["vision"], version = "^0.0.2"} +flwr-datasets = { extras = ["vision"], version = "^0.0.2" } diff --git a/examples/quickstart-mlx/run.sh b/examples/quickstart-mlx/run.sh index 70281049517d..40d211848c07 100755 --- a/examples/quickstart-mlx/run.sh +++ b/examples/quickstart-mlx/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-monai/.gitignore b/examples/quickstart-monai/.gitignore new file mode 100644 index 000000000000..a218cab9669e --- /dev/null +++ b/examples/quickstart-monai/.gitignore @@ -0,0 +1 @@ +MedNIST* diff --git a/examples/quickstart-monai/README.md b/examples/quickstart-monai/README.md new file mode 100644 index 000000000000..4a9afef4f86a --- /dev/null +++ b/examples/quickstart-monai/README.md @@ -0,0 +1,85 @@ +# Flower Example using MONAI + +This introductory example to Flower uses MONAI, but deep knowledge of MONAI is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. +Running this example in itself is quite easy. + +[MONAI](https://docs.monai.io/en/latest/index.html)(Medical Open Network for AI) is a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem. 
+ +Its ambitions are: + +- developing a community of academic, industrial and clinical researchers collaborating on a common foundation; + +- creating state-of-the-art, end-to-end training workflows for healthcare imaging; + +- providing researchers with an optimized and standardized way to create and evaluate deep learning models. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/quickstart-monai . && rm -rf _tmp && cd quickstart-monai +``` + +This will create a new directory called `quickstart-monai` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- data.py +-- model.py +-- server.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `monai` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Run the command below in your terminal to install the dependencies listed in the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with MONAI and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands. Clients will train a [DenseNet121](https://docs.monai.io/en/stable/networks.html#densenet121) from MONAI. If a GPU is present in your system, clients will use it. + +Start client 1 in the first terminal: + +```shell +python3 client.py --partition-id 0 +``` + +Start client 2 in the second terminal: + +```shell +python3 client.py --partition-id 1 +``` + +You will see that the federated training is starting. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-monai) for a detailed explanation.
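The `server.py` started above is renamed from `examples/mt-pytorch-callable/server.py` later in this diff, so its contents are not shown here. A minimal Flower server of the shape this README assumes would look roughly like the sketch below (the strategy and round count are assumptions, not the file's actual contents):

```python
# Hypothetical minimal server.py: three rounds of FedAvg on the default port.
import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=fl.server.strategy.FedAvg(),
    )
```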
diff --git a/examples/quickstart-monai/client.py b/examples/quickstart-monai/client.py new file mode 100644 index 000000000000..0ed943da83cc --- /dev/null +++ b/examples/quickstart-monai/client.py @@ -0,0 +1,61 @@ +import argparse +import warnings +from collections import OrderedDict + +import torch +from data import load_data +from model import test, train +from monai.networks.nets.densenet import DenseNet121 + +import flwr as fl + +warnings.filterwarnings("ignore", category=UserWarning) +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def __init__(self, net, trainloader, testloader, device): + self.net = net + self.trainloader = trainloader + self.testloader = testloader + self.device = device + + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in self.net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(self.net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + self.net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + self.set_parameters(parameters) + train(self.net, self.trainloader, epoch_num=1, device=self.device) + return self.get_parameters(config={}), len(self.trainloader), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + loss, accuracy = test(self.net, self.testloader, self.device) + return loss, len(self.testloader), {"accuracy": accuracy} + + +if __name__ == "__main__": + total_partitions = 10 + parser = argparse.ArgumentParser() + parser.add_argument( + "--partition-id", type=int, choices=range(total_partitions), required=True + ) + args = parser.parse_args() + + # Load model and data (DenseNet121, MedNIST) + trainloader, _, testloader, num_class = load_data( + total_partitions, args.partition_id + ) + net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_class).to(DEVICE) + + # Start Flower client + fl.client.start_numpy_client( + server_address="127.0.0.1:8080", + client=FlowerClient(net, trainloader, testloader, DEVICE), + ) diff --git a/examples/quickstart-monai/data.py b/examples/quickstart-monai/data.py new file mode 100644 index 000000000000..d184476522e8 --- /dev/null +++ b/examples/quickstart-monai/data.py @@ -0,0 +1,158 @@ +import os +import tarfile +from urllib import request + +import numpy as np +from monai.data import DataLoader, Dataset +from monai.transforms import ( + Compose, + EnsureChannelFirst, + LoadImage, + RandFlip, + RandRotate, + RandZoom, + ScaleIntensity, + ToTensor, +) + + +def _partition(files_list, labels_list, num_shards, index): + total_size = len(files_list) + assert total_size == len( + labels_list + ), "List of datapoints and labels must be of the same length" + shard_size = total_size // num_shards + + # Calculate start and end indices for the shard + start_idx = index * shard_size + if index == num_shards - 1: + # Last shard takes the remainder + end_idx = total_size + else: + end_idx = start_idx + shard_size + + # Create a subset for the shard + files = files_list[start_idx:end_idx] + labels = labels_list[start_idx:end_idx] + return files, labels + + +def load_data(num_shards, index): + image_file_list, image_label_list, _, num_class = _download_data() + + # Get partition given index + files_list, labels_list = _partition( + image_file_list, image_label_list, num_shards, index + ) + + trainX, trainY, valX, valY, testX, testY = _split_data(
files_list, labels_list, len(files_list) + ) + train_transforms, val_transforms = _get_transforms() + + train_ds = MedNISTDataset(trainX, trainY, train_transforms) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True) + + val_ds = MedNISTDataset(valX, valY, val_transforms) + val_loader = DataLoader(val_ds, batch_size=300) + + test_ds = MedNISTDataset(testX, testY, val_transforms) + test_loader = DataLoader(test_ds, batch_size=300) + + return train_loader, val_loader, test_loader, num_class + + +class MedNISTDataset(Dataset): + def __init__(self, image_files, labels, transforms): + self.image_files = image_files + self.labels = labels + self.transforms = transforms + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, index): + return self.transforms(self.image_files[index]), self.labels[index] + + +def _download_data(): + data_dir = "./MedNIST/" + _download_and_extract( + "https://dl.dropboxusercontent.com/s/5wwskxctvcxiuea/MedNIST.tar.gz", + os.path.join(data_dir), + ) + + class_names = sorted( + [x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))] + ) + num_class = len(class_names) + image_files = [ + [ + os.path.join(data_dir, class_name, x) + for x in os.listdir(os.path.join(data_dir, class_name)) + ] + for class_name in class_names + ] + image_file_list = [] + image_label_list = [] + for i, class_name in enumerate(class_names): + image_file_list.extend(image_files[i]) + image_label_list.extend([i] * len(image_files[i])) + num_total = len(image_label_list) + return image_file_list, image_label_list, num_total, num_class + + +def _split_data(image_file_list, image_label_list, num_total): + valid_frac, test_frac = 0.1, 0.1 + trainX, trainY = [], [] + valX, valY = [], [] + testX, testY = [], [] + + for i in range(num_total): + rann = np.random.random() + if rann < valid_frac: + valX.append(image_file_list[i]) + valY.append(image_label_list[i]) + elif rann < test_frac + valid_frac: + testX.append(image_file_list[i]) + testY.append(image_label_list[i]) + else: + trainX.append(image_file_list[i]) + trainY.append(image_label_list[i]) + + return trainX, trainY, valX, valY, testX, testY + + +def _get_transforms(): + train_transforms = Compose( + [ + LoadImage(image_only=True), + EnsureChannelFirst(), + ScaleIntensity(), + RandRotate(range_x=15, prob=0.5, keep_size=True), + RandFlip(spatial_axis=0, prob=0.5), + RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True), + ToTensor(), + ] + ) + + val_transforms = Compose( + [LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity(), ToTensor()] + ) + + return train_transforms, val_transforms + + +def _download_and_extract(url, dest_folder): + if not os.path.isdir(dest_folder): + # Download the tar.gz file + tar_gz_filename = url.split("/")[-1] + if not os.path.isfile(tar_gz_filename): + with request.urlopen(url) as response, open( + tar_gz_filename, "wb" + ) as out_file: + out_file.write(response.read()) + + # Extract the tar.gz file + with tarfile.open(tar_gz_filename, "r:gz") as tar_ref: + tar_ref.extractall() diff --git a/examples/quickstart-monai/model.py b/examples/quickstart-monai/model.py new file mode 100644 index 000000000000..4c74d50553e4 --- /dev/null +++ b/examples/quickstart-monai/model.py @@ -0,0 +1,33 @@ +import torch + + +def train(model, train_loader, epoch_num, device): + loss_function = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), 1e-5) + for _ in range(epoch_num): + model.train() + for inputs, labels in train_loader: 
+ optimizer.zero_grad() + loss_function(model(inputs.to(device)), labels.to(device)).backward() + optimizer.step() + + +def test(model, test_loader, device): + model.eval() + loss = 0.0 + y_true = list() + y_pred = list() + loss_function = torch.nn.CrossEntropyLoss() + with torch.no_grad(): + for test_images, test_labels in test_loader: + out = model(test_images.to(device)) + test_labels = test_labels.to(device) + loss += loss_function(out, test_labels).item() + pred = out.argmax(dim=1) + for i in range(len(pred)): + y_true.append(test_labels[i].item()) + y_pred.append(pred[i].item()) + accuracy = sum([1 if t == p else 0 for t, p in zip(y_true, y_pred)]) / len( + test_loader.dataset + ) + return loss, accuracy diff --git a/examples/mt-pytorch-callable/pyproject.toml b/examples/quickstart-monai/pyproject.toml similarity index 51% rename from examples/mt-pytorch-callable/pyproject.toml rename to examples/quickstart-monai/pyproject.toml index 0d1a91836006..66a56ee2270b 100644 --- a/examples/mt-pytorch-callable/pyproject.toml +++ b/examples/quickstart-monai/pyproject.toml @@ -3,14 +3,17 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "quickstart-pytorch" +name = "quickstart-monai" version = "0.1.0" -description = "PyTorch Federated Learning Quickstart with Flower" +description = "MONAI Federated Learning Quickstart with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } +flwr = ">=1.0,<2.0" torch = "1.13.1" -torchvision = "0.14.1" tqdm = "4.65.0" +scikit-learn = "1.3.1" +monai = { version = "1.3.0", extras=["gdown", "nibabel", "tqdm", "itk"] } +numpy = "1.24.4" +pillow = "10.2.0" diff --git a/examples/quickstart-monai/requirements.txt b/examples/quickstart-monai/requirements.txt new file mode 100644 index 000000000000..e3f1e463c629 --- /dev/null +++ b/examples/quickstart-monai/requirements.txt @@ -0,0 +1,7 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +tqdm==4.65.0 +scikit-learn==1.3.1 +monai[gdown,nibabel,tqdm,itk]==1.3.0 +numpy==1.24.4 +pillow==10.2.0 diff --git a/examples/mt-pytorch-callable/run.sh b/examples/quickstart-monai/run.sh similarity index 74% rename from examples/mt-pytorch-callable/run.sh rename to examples/quickstart-monai/run.sh index d2bf34f834b1..1da60bccb86d 100755 --- a/examples/mt-pytorch-callable/run.sh +++ b/examples/quickstart-monai/run.sh @@ -2,8 +2,7 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" +python -c "from data import _download_data; _download_data()" echo "Starting server" python server.py & @@ -11,7 +10,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/mt-pytorch-callable/server.py b/examples/quickstart-monai/server.py similarity index 100% rename from examples/mt-pytorch-callable/server.py rename to examples/quickstart-monai/server.py diff --git a/examples/quickstart-mxnet/client.py b/examples/quickstart-mxnet/client.py index b0c937f7350f..6c2b2e99775d 100644 --- a/examples/quickstart-mxnet/client.py +++ b/examples/quickstart-mxnet/client.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - 
import flwr as fl import numpy as np import mxnet as mx
diff --git a/examples/quickstart-mxnet/pyproject.toml b/examples/quickstart-mxnet/pyproject.toml
index 952683eb90f6..b00b3ddfe412 100644
--- a/examples/quickstart-mxnet/pyproject.toml
+++ b/examples/quickstart-mxnet/pyproject.toml
@@ -6,7 +6,7 @@ name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN"
-authors = ["The Flower Authors "]
+authors = ["The Flower Authors "]
[tool.poetry.dependencies] python = ">=3.8,<3.11"
diff --git a/examples/quickstart-mxnet/server.py b/examples/quickstart-mxnet/server.py
index 29cbce1884d1..871aa4e8ec99 100644
--- a/examples/quickstart-mxnet/server.py
+++ b/examples/quickstart-mxnet/server.py
@@ -1,6 +1,5 @@ """Flower server example."""
-
import flwr as fl if __name__ == "__main__":
diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md
index a25e6ea6ee36..dd69f3ead3cb 100644
--- a/examples/quickstart-pandas/README.md
+++ b/examples/quickstart-pandas/README.md
@@ -1,6 +1,6 @@ # Flower Example using Pandas
-This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to
+This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to
download, partition and preprocess the dataset. Running this example in itself is quite easy.
@@ -70,13 +70,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell
-$ python3 client.py --node-id 0
+$ python3 client.py --partition-id 0
``` Start client 2 in the second terminal: ```shell
-$ python3 client.py --node-id 1
+$ python3 client.py --partition-id 1
```
-You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-pandas.html) for a detailed explanation.
+You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look at the [Flower Quickstart documentation](https://flower.ai/docs/quickstart-pandas.html) for a detailed explanation.
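A pattern that repeats across the client diffs below: `fl.client.start_numpy_client(...)` is replaced by `fl.client.start_client(...)`, with the `NumPyClient` instance converted via `.to_client()`. A minimal sketch of the before/after, assuming `flwr>=1.0` with `NumPyClient.to_client()` available (the no-op client here is purely illustrative, not part of any example):

```python
import flwr as fl


class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return []  # model parameters as a list of NumPy ndarrays

    def fit(self, parameters, config):
        return [], 0, {}  # (updated parameters, num_examples, metrics)

    def evaluate(self, parameters, config):
        return 0.0, 0, {}  # (loss, num_examples, metrics)


# Before: fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
fl.client.start_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient().to_client(),
)
```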
diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index c2f2605594d5..c52b7c65b04c 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -42,14 +42,14 @@ def fit( parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, - help="Specifies the node id of artificially partitioned datasets.", + help="Specifies the partition id of artificially partitioned datasets.", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) @@ -59,7 +59,7 @@ def fit( X = dataset[column_names] # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(X), + client=FlowerClient(X).to_client(), ) diff --git a/examples/quickstart-pandas/pyproject.toml b/examples/quickstart-pandas/pyproject.toml index 6229210d6488..2e6b1424bb54 100644 --- a/examples/quickstart-pandas/pyproject.toml +++ b/examples/quickstart-pandas/pyproject.toml @@ -7,7 +7,7 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-pandas/run.sh b/examples/quickstart-pandas/run.sh index 571fa8bfb3e4..2ae1e582b8cf 100755 --- a/examples/quickstart-pandas/run.sh +++ b/examples/quickstart-pandas/run.sh @@ -4,7 +4,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i} & + python client.py --partition-id ${i} & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index 1287b50bca65..fb29c7e9e9ea 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1,6 +1,6 @@ # Flower Example using PyTorch Lightning -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -57,20 +57,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. 
To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python client.py --node-id 0 +python client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python client.py --node-id 1 +python client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index 1dabd5732b9b..6e21259cc492 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -58,22 +58,22 @@ def _set_parameters(model, parameters): def main() -> None: parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, 10), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - node_id = args.node_id + partition_id = args.partition_id # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data(node_id) + train_loader, val_loader, test_loader = mnist.load_data(partition_id) # Flower client - client = FlowerClient(model, train_loader, val_loader, test_loader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/quickstart-pytorch-lightning/pyproject.toml b/examples/quickstart-pytorch-lightning/pyproject.toml index 853ef9c1646f..a09aaa3d65b5 100644 --- a/examples/quickstart-pytorch-lightning/pyproject.toml +++ b/examples/quickstart-pytorch-lightning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" description = "Federated Learning Quickstart with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/examples/quickstart-pytorch-lightning/run.sh b/examples/quickstart-pytorch-lightning/run.sh index 60893a9a055b..62a1dac199bd 100755 --- a/examples/quickstart-pytorch-lightning/run.sh +++ b/examples/quickstart-pytorch-lightning/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 6de0dcf7ab32..02c9b4b38498 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Example using PyTorch -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. 
However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup @@ -55,20 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index 1edb42d1ec81..e640ce111dff 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -69,10 +69,10 @@ def test(net, testloader): return loss, accuracy -def load_data(node_id): +def load_data(partition_id): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( @@ -94,20 +94,20 @@ def apply_transforms(batch): # 2. 
Federation of the pipeline with Flower # ############################################################################# -# Get node id +# Get partition id parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], required=True, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) -node_id = parser.parse_args().node_id +partition_id = parser.parse_args().partition_id # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data(node_id=node_id) +trainloader, testloader = load_data(partition_id=partition_id) # Define Flower client @@ -132,7 +132,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index ec6a3af8c5b4..d8e1503dd8a7 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-pytorch/run.sh b/examples/quickstart-pytorch/run.sh index cdace99bb8df..6ca9c8cafec9 100755 --- a/examples/quickstart-pytorch/run.sh +++ b/examples/quickstart-pytorch/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "$i" & + python client.py --partition-id "$i" & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md index d62525c96c18..a975a9392800 100644 --- a/examples/quickstart-sklearn-tabular/README.md +++ b/examples/quickstart-sklearn-tabular/README.md @@ -3,7 +3,7 @@ This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on "iris" (tabular) dataset. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset. ## Project Setup @@ -63,15 +63,15 @@ poetry run python3 server.py Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: ```shell -poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2} +poetry run python3 client.py --partition-id 0 # partition-id should be any of {0,1,2} ``` Alternatively you can run all of it in one shell as follows: ```shell poetry run python3 server.py & -poetry run python3 client.py --node-id 0 & -poetry run python3 client.py --node-id 1 +poetry run python3 client.py --partition-id 0 & +poetry run python3 client.py --partition-id 1 ``` You will see that Flower is starting a federated training. 
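The `--node-id` to `--partition-id` rename applied throughout these examples follows a single argparse pattern; a minimal sketch (with `N_CLIENTS` assumed to match the number of artificial partitions, three in this example):

```python
import argparse

N_CLIENTS = 3  # assumed number of artificial partitions

parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
    "--partition-id",  # previously --node-id
    type=int,
    choices=range(0, N_CLIENTS),
    required=True,
    help="Specifies the artificial data partition",
)
# argparse exposes --partition-id as the attribute partition_id
partition_id = parser.parse_args().partition_id
```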
diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py index 5dc0e88b3c75..fcab8f5d5612 100644 --- a/examples/quickstart-sklearn-tabular/client.py +++ b/examples/quickstart-sklearn-tabular/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) diff --git a/examples/quickstart-sklearn-tabular/pyproject.toml b/examples/quickstart-sklearn-tabular/pyproject.toml index 34a78048d3b0..86eab5c38df0 100644 --- a/examples/quickstart-sklearn-tabular/pyproject.toml +++ b/examples/quickstart-sklearn-tabular/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/quickstart-sklearn-tabular/run.sh +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-tabnet/client.py b/examples/quickstart-tabnet/client.py index da391a95710a..2289b1b55b3d 100644 --- a/examples/quickstart-tabnet/client.py +++ b/examples/quickstart-tabnet/client.py @@ -79,4 +79,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=TabNetClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=TabNetClient().to_client() +) diff --git a/examples/quickstart-tabnet/pyproject.toml b/examples/quickstart-tabnet/pyproject.toml index 948eaf085b86..18f1979791bd 100644 --- a/examples/quickstart-tabnet/pyproject.toml +++ b/examples/quickstart-tabnet/pyproject.toml @@ -6,12 +6,12 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.8.3" tabnet = "0.1.6" diff --git a/examples/quickstart-tensorflow/README.md b/examples/quickstart-tensorflow/README.md index 92d38c9340d7..8d5e9434b086 100644 --- a/examples/quickstart-tensorflow/README.md +++ b/examples/quickstart-tensorflow/README.md @@ -1,7 +1,7 @@ # Flower Example 
using TensorFlow/Keras This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index d998adbdd899..3e2035c09311 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -11,7 +11,7 @@ # Parse arguments parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=[0, 1, 2], required=True, @@ -26,7 +26,7 @@ # Download and partition dataset fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) -partition = fds.load_partition(args.node_id, "train") +partition = fds.load_partition(args.partition_id, "train") partition.set_format("numpy") # Divide data on each node: 80% train, 20% test @@ -52,4 +52,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() +) diff --git a/examples/quickstart-tensorflow/pyproject.toml b/examples/quickstart-tensorflow/pyproject.toml index e027a7353181..98aeb932cab9 100644 --- a/examples/quickstart-tensorflow/pyproject.toml +++ b/examples/quickstart-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/quickstart-tensorflow/run.sh b/examples/quickstart-tensorflow/run.sh index 439abea8df4b..76188f197e3e 100755 --- a/examples/quickstart-tensorflow/run.sh +++ b/examples/quickstart-tensorflow/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/secaggplus-mt/README.md b/examples/secaggplus-mt/README.md deleted file mode 100644 index 0b3b4db3942e..000000000000 --- a/examples/secaggplus-mt/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Secure Aggregation with Driver API - -This example contains highly experimental code. 
Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. - -## Installing Dependencies - -Project dependencies (such as and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. - -### Poetry - -```shell -poetry install -poetry shell -``` - -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: - -```shell -poetry run python3 -c "import flwr" -``` - -### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt -``` - -If you don't see any errors you're good to go! - -## Run with Driver API - -```bash -./run.sh -``` diff --git a/examples/secaggplus-mt/client.py b/examples/secaggplus-mt/client.py deleted file mode 100644 index f0f1348ee378..000000000000 --- a/examples/secaggplus-mt/client.py +++ /dev/null @@ -1,35 +0,0 @@ -import time - -import numpy as np - -import flwr as fl -from flwr.common import Status, FitIns, FitRes, Code -from flwr.common.parameter import ndarrays_to_parameters -from flwr.client.secure_aggregation import SecAggPlusHandler - - -# Define Flower client with the SecAgg+ protocol -class FlowerClient(fl.client.Client, SecAggPlusHandler): - def fit(self, fit_ins: FitIns) -> FitRes: - ret_vec = [np.ones(3)] - ret = FitRes( - status=Status(code=Code.OK, message="Success"), - parameters=ndarrays_to_parameters(ret_vec), - num_examples=1, - metrics={}, - ) - # Force a significant delay for testing purposes - if self._shared_state.sid == 0: - print(f"Client {self._shared_state.sid} dropped for testing purposes.") - time.sleep(4) - return ret - print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...") - return ret - - -# Start Flower client -fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient(), - transport="grpc-rere", -) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py deleted file mode 100644 index f5871f1b44e4..000000000000 --- a/examples/secaggplus-mt/driver.py +++ /dev/null @@ -1,207 +0,0 @@ -import random -import time -from typing import Dict, List, Tuple - -import numpy as np -from workflows import get_workflow_factory - -from flwr.common import Metrics, ndarrays_to_parameters -from flwr.driver import GrpcDriver -from flwr.proto import driver_pb2, node_pb2, task_pb2 -from flwr.server import History - - -# Convert instruction/result dict to/from list of TaskIns/TaskRes -def task_dict_to_task_ins_list( - task_dict: Dict[int, task_pb2.Task] -) -> List[task_pb2.TaskIns]: - def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: - _task.MergeFrom(_merge_task) - return _task - - return [ - task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - run_id=run_id, - run_id=run_id, - task=merge( - task, - task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=node_pb2.Node( - node_id=sampled_node_id, 
- # Must be False for this Secure Aggregation example - anonymous=False, - ), - ), - ), - ) - for sampled_node_id, task in task_dict.items() - ] - - -def task_res_list_to_task_dict( - task_res_list: List[task_pb2.TaskRes], -) -> Dict[int, task_pb2.Task]: - return {task_res.task.producer.node_id: task_res.task for task_res in task_res_list} - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) -# -------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 5 -sleep_time = 0.5 -time_out = 3.9 -num_rounds = 3 -parameters = ndarrays_to_parameters([np.ones(3)]) -wf_factory = get_workflow_factory() - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_run_res: driver_pb2.CreateRunResponse = driver.create_run( - req=driver_pb2.CreateRunRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -run_id = create_run_res.run_id -print(f"Created run id {run_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_node_ids: List[int] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. - # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids = [0] * num_client_nodes_per_round - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. 
- while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_node_ids: List[int] = [node.node_id for node in get_nodes_res.nodes] - - if len(all_node_ids) >= num_client_nodes_per_round: - # Sample client nodes - sampled_node_ids = random.sample( - all_node_ids, num_client_nodes_per_round - ) - break - - time.sleep(3) - - # Log sampled node IDs - time.sleep(sleep_time) - - workflow = wf_factory(parameters, sampled_node_ids) - node_messages = None - - while True: - try: - instructions: Dict[int, task_pb2.Task] = workflow.send(node_messages) - next(workflow) - except StopIteration: - break - # Schedule a task for all sampled nodes - task_ins_list: List[task_pb2.TaskIns] = task_dict_to_task_ins_list(instructions) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - start_time = time.time() - task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - if time.time() - start_time >= time_out: - break - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages = task_res_list_to_task_dict( - [res for res in all_task_res if res.task.HasField("sa")] - ) - workflow.close() - - # Slow down the start of the next round - time.sleep(sleep_time) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/secaggplus-mt/pyproject.toml b/examples/secaggplus-mt/pyproject.toml deleted file mode 100644 index 94d8defa3316..000000000000 --- a/examples/secaggplus-mt/pyproject.toml +++ /dev/null @@ -1,13 +0,0 @@ -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "secaggplus-mt" -version = "0.1.0" -description = "Secure Aggregation with Driver API" -authors = ["The Flower Authors "] - -[tool.poetry.dependencies] -python = ">=3.8,<3.11" -flwr-nightly = { version = "^1.5.0.dev20230629", extras = ["simulation", "rest"] } diff --git a/examples/secaggplus-mt/requirements.txt b/examples/secaggplus-mt/requirements.txt deleted file mode 100644 index eeed6941afc7..000000000000 --- 
a/examples/secaggplus-mt/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -flwr-nightly[simulation,rest] diff --git a/examples/secaggplus-mt/run.sh b/examples/secaggplus-mt/run.sh deleted file mode 100755 index 5cc769f6cbd8..000000000000 --- a/examples/secaggplus-mt/run.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Kill any currently running client.py processes -pkill -f 'python client.py' - -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' - -# Start the flower server -echo "Starting flower server in background..." -flower-server --grpc-rere > /dev/null 2>&1 & -sleep 2 - -# Number of client processes to start -N=5 # Replace with your desired value - -echo "Starting $N clients in background..." - -# Start N client processes -for i in $(seq 1 $N) -do - python client.py > /dev/null 2>&1 & - # python client.py & - sleep 0.1 -done - -echo "Starting driver..." -python driver.py - -echo "Clearing background processes..." - -# Kill any currently running client.py processes -pkill -f 'python client.py' - -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py deleted file mode 100644 index 3117e308a498..000000000000 --- a/examples/secaggplus-mt/workflows.py +++ /dev/null @@ -1,372 +0,0 @@ -import random -from logging import WARNING -from typing import Callable, Dict, Generator, List - -import numpy as np - -from flwr.common import ( - Parameters, - Scalar, - bytes_to_ndarray, - log, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) -from flwr.common.secure_aggregation.crypto.shamir import combine_shares -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, - bytes_to_public_key, - generate_shared_key, -) -from flwr.common.secure_aggregation.ndarrays_arithmetic import ( - factor_extract, - get_parameters_shape, - get_zero_parameters, - parameters_addition, - parameters_mod, - parameters_subtraction, -) -from flwr.common.secure_aggregation.quantization import dequantize, quantize -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_DESTINATION_LIST, - KEY_MASKED_PARAMETERS, - KEY_MOD_RANGE, - KEY_PARAMETERS, - KEY_PUBLIC_KEY_1, - KEY_PUBLIC_KEY_2, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SECURE_ID_LIST, - KEY_SHARE_LIST, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, -) -from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.serde import named_values_from_proto, named_values_to_proto -from flwr.common.typing import Value -from flwr.proto.task_pb2 import SecureAggregation, Task - - -LOG_EXPLAIN = True - - -def get_workflow_factory() -> ( - Callable[[Parameters, List[int]], Generator[Dict[int, Task], Dict[int, Task], None]] -): - return _wrap_workflow_with_sec_agg - - -def _wrap_in_task(named_values: Dict[str, Value]) -> Task: - return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values))) - - -def _get_from_task(task: Task) -> Dict[str, Value]: - return named_values_from_proto(task.sa.named_values) - - -_secure_aggregation_configuration = { - KEY_SHARE_NUMBER: 3, - KEY_THRESHOLD: 2, - KEY_CLIPPING_RANGE: 3.0, - 
KEY_TARGET_RANGE: 1 << 20, - KEY_MOD_RANGE: 1 << 30, -} - - -def workflow_with_sec_agg( - parameters: Parameters, - sampled_node_ids: List[int], - sec_agg_config: Dict[str, Scalar], -) -> Generator[Dict[int, Task], Dict[int, Task], None]: - """ - =============== Setup stage =============== - """ - # Protocol config - num_samples = len(sampled_node_ids) - num_shares = sec_agg_config[KEY_SHARE_NUMBER] - threshold = sec_agg_config[KEY_THRESHOLD] - mod_range = sec_agg_config[KEY_MOD_RANGE] - # Quantization config - clipping_range = sec_agg_config[KEY_CLIPPING_RANGE] - target_range = sec_agg_config[KEY_TARGET_RANGE] - - if LOG_EXPLAIN: - _quantized = quantize( - [np.ones(3) for _ in range(num_samples)], clipping_range, target_range - ) - print( - "\n\n################################ Introduction ################################\n" - "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of\n" - "model updates for demonstration purposes.\n" - "Client 0 is configured to drop out before uploading the masked vector.\n" - f"After quantization, the raw vectors will be:" - ) - for i in range(1, num_samples): - print(f"\t{_quantized[i]} from Client {i}") - print( - f"Numbers are rounded to integers stochastically during the quantization\n" - ", and thus not all entries are identical." - ) - print( - "The above raw vectors are hidden from the driver through adding masks.\n" - ) - print( - "########################## Secure Aggregation Start ##########################" - ) - cfg = { - KEY_STAGE: STAGE_SETUP, - KEY_SAMPLE_NUMBER: num_samples, - KEY_SHARE_NUMBER: num_shares, - KEY_THRESHOLD: threshold, - KEY_CLIPPING_RANGE: clipping_range, - KEY_TARGET_RANGE: target_range, - KEY_MOD_RANGE: mod_range, - } - # The number of shares should better be odd in the SecAgg+ protocol. - if num_samples != num_shares and num_shares & 0x1 == 0: - log(WARNING, "Number of shares in the SecAgg+ protocol should be odd.") - num_shares += 1 - - # Randomly assign secure IDs to clients - sids = [i for i in range(len(sampled_node_ids))] - random.shuffle(sids) - nid2sid = dict(zip(sampled_node_ids, sids)) - sid2nid = {sid: nid for nid, sid in nid2sid.items()} - # Build neighbour relations (node ID -> secure IDs of neighbours) - half_share = num_shares >> 1 - nid2neighbours = { - node_id: { - (nid2sid[node_id] + offset) % num_samples - for offset in range(-half_share, half_share + 1) - } - for node_id in sampled_node_ids - } - - surviving_node_ids = sampled_node_ids - if LOG_EXPLAIN: - print( - f"Sending configurations to {num_samples} clients and allocating secure IDs..." 
- ) - # Send setup configuration to clients - yield { - node_id: _wrap_in_task( - named_values={ - **cfg, - KEY_SECURE_ID: nid2sid[node_id], - } - ) - for node_id in surviving_node_ids - } - # Receive public keys from clients and build the dict - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - - if LOG_EXPLAIN: - print(f"Received public keys from {len(surviving_node_ids)} clients.") - - sid2public_keys = {} - for node_id, task in node_messages.items(): - key_dict = _get_from_task(task) - pk1, pk2 = key_dict[KEY_PUBLIC_KEY_1], key_dict[KEY_PUBLIC_KEY_2] - sid2public_keys[nid2sid[node_id]] = [pk1, pk2] - - """ - =============== Share keys stage =============== - """ - if LOG_EXPLAIN: - print(f"\nForwarding public keys...") - # Broadcast public keys to clients - yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_SHARE_KEYS, - **{ - str(sid): value - for sid, value in sid2public_keys.items() - if sid in nid2neighbours[node_id] - }, - } - ) - for node_id in surviving_node_ids - } - - # Receive secret key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received encrypted key shares from {len(surviving_node_ids)} clients.") - # Build forward packet list dictionary - srcs, dsts, ciphertexts = [], [], [] - fwd_ciphertexts: Dict[int, List[bytes]] = { - nid2sid[nid]: [] for nid in surviving_node_ids - } # dest secure ID -> list of ciphertexts - fwd_srcs: Dict[int, List[bytes]] = { - sid: [] for sid in fwd_ciphertexts - } # dest secure ID -> list of src secure IDs - for node_id, task in node_messages.items(): - res_dict = _get_from_task(task) - srcs += [nid2sid[node_id]] * len(res_dict[KEY_DESTINATION_LIST]) - dsts += res_dict[KEY_DESTINATION_LIST] - ciphertexts += res_dict[KEY_CIPHERTEXT_LIST] - - for src, dst, ciphertext in zip(srcs, dsts, ciphertexts): - if dst in fwd_ciphertexts: - fwd_ciphertexts[dst].append(ciphertext) - fwd_srcs[dst].append(src) - - """ - =============== Collect masked input stage =============== - """ - - if LOG_EXPLAIN: - print(f"\nForwarding encrypted key shares and requesting masked input...") - # Send encrypted secret key shares to clients (plus model parameters) - weights = parameters_to_ndarrays(parameters) - yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_COLLECT_MASKED_INPUT, - KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]], - KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]], - KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights], - } - ) - for node_id in surviving_node_ids - } - # Collect masked input from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - # Get shape of vector sent by first client - masked_vector = [np.array([0], dtype=int)] + get_zero_parameters( - [w.shape for w in weights] - ) - # Add all collected masked vectors and compuute available and dropout clients set - dead_sids = { - nid2sid[node_id] - for node_id in sampled_node_ids - if node_id not in surviving_node_ids - } - active_sids = {nid2sid[node_id] for node_id in surviving_node_ids} - if LOG_EXPLAIN: - for sid in dead_sids: - print(f"Client {sid} dropped out.") - for node_id, task in node_messages.items(): - named_values = _get_from_task(task) - client_masked_vec = named_values[KEY_MASKED_PARAMETERS] - client_masked_vec = [bytes_to_ndarray(b) for b in client_masked_vec] - if LOG_EXPLAIN: - print(f"Received {client_masked_vec[1]} from Client {nid2sid[node_id]}.") - masked_vector 
= parameters_addition(masked_vector, client_masked_vec) - masked_vector = parameters_mod(masked_vector, mod_range) - """ - =============== Unmask stage =============== - """ - - if LOG_EXPLAIN: - print("\nRequesting key shares to unmask the aggregate vector...") - # Send secure IDs of active and dead clients. - yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_UNMASK, - KEY_DEAD_SECURE_ID_LIST: list(dead_sids & nid2neighbours[node_id]), - KEY_ACTIVE_SECURE_ID_LIST: list(active_sids & nid2neighbours[node_id]), - } - ) - for node_id in surviving_node_ids - } - # Collect key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received key shares from {len(surviving_node_ids)} clients.") - # Build collected shares dict - collected_shares_dict: Dict[int, List[bytes]] = {} - for nid in sampled_node_ids: - collected_shares_dict[nid2sid[nid]] = [] - - if len(surviving_node_ids) < threshold: - raise Exception("Not enough available clients after unmask vectors stage") - for _, task in node_messages.items(): - named_values = _get_from_task(task) - for owner_sid, share in zip( - named_values[KEY_SECURE_ID_LIST], named_values[KEY_SHARE_LIST] - ): - collected_shares_dict[owner_sid].append(share) - # Remove mask for every client who is available before ask vectors stage, - # divide vector by first element - active_sids, dead_sids = set(active_sids), set(dead_sids) - for sid, share_list in collected_shares_dict.items(): - if len(share_list) < threshold: - raise Exception( - "Not enough shares to recover secret in unmask vectors stage" - ) - secret = combine_shares(share_list) - if sid in active_sids: - # The seed for PRG is the private mask seed of an active client. - private_mask = pseudo_rand_gen( - secret, mod_range, get_parameters_shape(masked_vector) - ) - masked_vector = parameters_subtraction(masked_vector, private_mask) - else: - # The seed for PRG is the secret key 1 of a dropped client. - neighbor_list = list(nid2neighbours[sid2nid[sid]]) - neighbor_list.remove(sid) - - for neighbor_sid in neighbor_list: - shared_key = generate_shared_key( - bytes_to_private_key(secret), - bytes_to_public_key(sid2public_keys[neighbor_sid][0]), - ) - pairwise_mask = pseudo_rand_gen( - shared_key, mod_range, get_parameters_shape(masked_vector) - ) - if sid > neighbor_sid: - masked_vector = parameters_addition(masked_vector, pairwise_mask) - else: - masked_vector = parameters_subtraction(masked_vector, pairwise_mask) - recon_parameters = parameters_mod(masked_vector, mod_range) - # Divide vector by number of clients who have given us their masked vector - # i.e. 
those participating in final unmask vectors stage
- total_weights_factor, recon_parameters = factor_extract(recon_parameters)
- if LOG_EXPLAIN:
- print(f"Unmasked sum of vectors (quantized): {recon_parameters[0]}")
- # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
- aggregated_vector = dequantize(
- quantized_parameters=recon_parameters,
- clipping_range=clipping_range,
- target_range=target_range,
- )
- aggregated_vector[0] -= (len(active_sids) - 1) * clipping_range
- if LOG_EXPLAIN:
- print(f"Unmasked sum of vectors (dequantized): {aggregated_vector[0]}")
- print(
- f"Aggregate vector using FedAvg: {aggregated_vector[0] / len(active_sids)}"
- )
- print(
- "########################### Secure Aggregation End ###########################\n\n"
- )
- aggregated_parameters = ndarrays_to_parameters(aggregated_vector)
- # Update model parameters
- parameters.tensors = aggregated_parameters.tensors
- parameters.tensor_type = aggregated_parameters.tensor_type
-
-
-def _wrap_workflow_with_sec_agg(
- parameters: Parameters, sampled_node_ids: List[int]
-) -> Generator[Dict[int, Task], Dict[int, Task], None]:
- return workflow_with_sec_agg(
- parameters, sampled_node_ids, sec_agg_config=_secure_aggregation_configuration
- )
diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md
index 11b7a3364376..93f9e1acbac7 100644
--- a/examples/simulation-pytorch/README.md
+++ b/examples/simulation-pytorch/README.md
@@ -1,6 +1,6 @@ # Flower Simulation example using PyTorch
-This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default.
+This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This example uses 100 clients by default.
## Running the example (via Jupyter Notebook)
@@ -54,17 +54,13 @@ Write the command below in your terminal to install the dependencies according t pip install -r requirements.txt ```
-### Run Federated Learning Example
+### Run with `start_simulation()`
-```bash
-# You can run the example without activating your environemnt
-poetry run python sim.py
+Ensure you have activated your environment, then:
-# Or by first activating it
-poetry shell
+```bash
# and then run the example python sim.py
-# you can exit your environment by typing "exit"
```
You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 1xCPU core. For example:
@@ -73,10 +69,29 @@ You can adjust the CPU/GPU resources you assign to each of your virtual clients.
# Will assign 2xCPUs to each client python sim.py --num_cpus=2 -# Will assign 2xCPUs and 20% of the GPU's VRAM to each client -# This means that you can have 5 concurrent clients on each GPU +# Will assign 2xCPUs and 25% of the GPU's VRAM to each client +# This means that you can have 4 concurrent clients on each GPU # (assuming you have enough CPUs) -python sim.py --num_cpus=2 --num_gpus=0.2 +python sim.py --num_cpus=2 --num_gpus=0.25 +``` + +### Run with Flower Next (preview) + +Ensure you have activated your environment, then execute the command below. All `ClientApp` instances will run on CPU but the `ServerApp` will run on the GPU if one is available. Note that this is the case because the `Simulation Engine` only exposes certain resources to the `ClientApp` (based on the `client_resources` in `--backend-config`). + +```bash +# Run with the default backend-config. +# `--server-app` points to the `server` object in the sim.py file in this example. +# `--client-app` points to the `client` object in the sim.py file in this example. +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 +``` + +You can change the default resources assigned to each `ClientApp` by means of the `--backend-config` argument: + +```bash +# Tells the VCE to reserve 2x CPUs and 25% of available VRAM for each ClientApp +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 \ + --backend-config='{"client_resources": {"num_cpus":2, "num_gpus":0.25}}' ``` -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. diff --git a/examples/simulation-pytorch/pyproject.toml b/examples/simulation-pytorch/pyproject.toml index 07918c0cd17c..5978c17f2c60 100644 --- a/examples/simulation-pytorch/pyproject.toml +++ b/examples/simulation-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "simulation-pytorch" version = "0.1.0" description = "Federated Learning Simulation with Flower and PyTorch" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -17,4 +17,3 @@ torchvision = "0.16.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.27.0" - diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index 508630cf9422..e27721a7fa5f 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -30,7 +30,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.dev/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." 
+ "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.ai/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -178,7 +178,7 @@ "\n", "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." ] }, { @@ -290,7 +290,7 @@ " self.model.to(self.device) # send model to device\n", "\n", " def set_parameters(self, parameters):\n", - " \"\"\"With the model paramters received from the server,\n", + " \"\"\"With the model parameters received from the server,\n", " overwrite the uninitialise model in this class with them.\"\"\"\n", "\n", " params_dict = zip(self.model.state_dict().keys(), parameters)\n", @@ -509,7 +509,7 @@ " valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n", "\n", " # Create and return client\n", - " return FlowerClient(trainloader, valloader)\n", + " return FlowerClient(trainloader, valloader).to_client()\n", "\n", " return client_fn\n", "\n", @@ -605,11 +605,11 @@ "\n", "Get all resources you need!\n", "\n", - "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n", - "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** Our complete documenation: https://flower.ai/docs/\n", + "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n" ] } ], diff --git a/examples/simulation-pytorch/sim.py b/examples/simulation-pytorch/sim.py index 68d9426e83ab..ca9e6f0e8366 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -29,9 +29,9 @@ default=0.0, help="Ratio of GPU memory to assign to a virtual client", ) -parser.add_argument("--num_rounds", type=int, default=10, help="Number of FL rounds.") NUM_CLIENTS = 100 +NUM_ROUNDS = 10 # Flower client, adapted from Pytorch quickstart example @@ -104,7 +104,7 @@ def client_fn(cid: str) -> fl.client.Client: valset = valset.with_transform(apply_transforms) # Create and return client - return FlowerClient(trainset, valset) + return FlowerClient(trainset, valset).to_client() return client_fn @@ -167,28 +167,36 @@ def evaluate( return evaluate +# Download MNIST dataset and partition it +mnist_fds = FederatedDataset(dataset="mnist", 
partitioners={"train": NUM_CLIENTS}) +centralized_testset = mnist_fds.load_full("test") + +# Configure the strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=0.1, # Sample 10% of available clients for training + fraction_evaluate=0.05, # Sample 5% of available clients for evaluation + min_available_clients=10, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function +) + +# ClientApp for Flower-Next +client = fl.client.ClientApp( + client_fn=get_client_fn(mnist_fds), +) + +# ServerApp for Flower-Next +server = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), + strategy=strategy, +) + + def main(): # Parse input arguments args = parser.parse_args() - # Download MNIST dataset and partition it - mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) - centralized_testset = mnist_fds.load_full("test") - - # Configure the strategy - strategy = fl.server.strategy.FedAvg( - fraction_fit=0.1, # Sample 10% of available clients for training - fraction_evaluate=0.05, # Sample 5% of available clients for evaluation - min_fit_clients=10, # Never sample less than 10 clients for training - min_evaluate_clients=5, # Never sample less than 5 clients for evaluation - min_available_clients=int( - NUM_CLIENTS * 0.75 - ), # Wait until at least 75 clients are available - on_fit_config_fn=fit_config, - evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics - evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function - ) - # Resources to be assigned to each virtual client client_resources = { "num_cpus": args.num_cpus, @@ -200,7 +208,7 @@ def main(): client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, client_resources=client_resources, - config=fl.server.ServerConfig(num_rounds=args.num_rounds), + config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), strategy=strategy, actor_kwargs={ "on_actor_init_fn": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients diff --git a/examples/simulation-tensorflow/README.md b/examples/simulation-tensorflow/README.md index f0d94f343d37..917d7b34c7af 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using TensorFlow/Keras -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. 
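As in the PyTorch simulation above, data handling goes through Flower Datasets; a minimal sketch of the partitioning pattern these simulation examples rely on (assuming `flwr-datasets` is installed; names mirror `sim.py`):

```python
from flwr_datasets import FederatedDataset

NUM_CLIENTS = 100  # default number of partitions in these simulation examples

# Split MNIST's train split into NUM_CLIENTS partitions, one per virtual client
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
partition = mnist_fds.load_partition(0)  # training data for partition/client 0
centralized_testset = mnist_fds.load_full("test")  # held out for server-side evaluation
```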
## Running the example (via Jupyter Notebook)

@@ -53,29 +53,46 @@ Write the command below in your terminal to install the dependencies according t
 pip install -r requirements.txt
 ```

-### Run Federated Learning Example
+### Run with `start_simulation()`

-```bash
-# You can run the example without activating your environemnt
-poetry run python sim.py
+Ensure you have activated your environment, then:

-# Or by first activating it
-poetry shell
+```bash
 # and then run the example
 python sim.py
-# you can exit your environment by typing "exit"
 ```

-You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 1xCPU core. For example:
+You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 2xCPU cores. For example:

 ```bash
 # Will assign 2xCPUs to each client
 python sim.py --num_cpus=2

-# Will assign 2xCPUs and 20% of the GPU's VRAM to each client
-# This means that you can have 5 concurrent clients on each GPU
+# Will assign 2xCPUs and 25% of the GPU's VRAM to each client
+# This means that you can have 4 concurrent clients on each GPU
 # (assuming you have enough CPUs)
-python sim.py --num_cpus=2 --num_gpus=0.2
+python sim.py --num_cpus=2 --num_gpus=0.25
+```
+
+Because TensorFlow by default maps all the available VRAM, we need to [enable GPU memory growth](https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth). See how it is done in the example (`sim.py`) for both the "main" process (where the server/strategy runs) and for the clients (using the `actor_kwargs`).
+
+### Run with Flower Next (preview)
+
+Ensure you have activated your environment, then execute the command below. All `ClientApp` instances will run on CPU but the `ServerApp` will run on the GPU if one is available. Note that this is the case because the `Simulation Engine` only exposes certain resources to the `ClientApp` (based on the `client_resources` in `--backend-config`). For TensorFlow simulations, it is desirable to make use of TF's [memory growth](https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth) feature. You can enable that easily with the `--enable-tf-gpu-growth` flag.
+
+```bash
+# Run with the default backend-config.
+# `--server-app` points to the `server` object in the sim.py file in this example.
+# `--client-app` points to the `client` object in the sim.py file in this example.
+flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 --enable-tf-gpu-growth
+```
+
+You can change the default resources assigned to each `ClientApp` using the `--backend-config` argument.
+
+```bash
+# Tells the VCE to reserve 2x CPUs and 25% of available VRAM for each ClientApp
+flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 \
+ --backend-config='{"client_resources": {"num_cpus":2, "num_gpus":0.25}}' --enable-tf-gpu-growth
 ```

-Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.
+Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.
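For reference, the GPU memory-growth setup this README points to boils down to a few lines of standard TensorFlow. This is a sketch using the public `tf.config` API, not a verbatim excerpt from `sim.py`:

```python
import tensorflow as tf

def enable_gpu_growth() -> None:
    """Let TensorFlow allocate VRAM on demand instead of mapping it all upfront."""
    # Must run before TensorFlow initialises the GPUs (i.e. before the first op).
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
```

The same callable needs to run once in the main process (where the server and strategy live) and once inside every client actor; the example wires the client side through `actor_kwargs`, and Flower Next exposes the equivalent behaviour via the `--enable-tf-gpu-growth` flag shown above.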
diff --git a/examples/simulation-tensorflow/pyproject.toml b/examples/simulation-tensorflow/pyproject.toml index f2e7bd3006c0..ad8cc2032b2d 100644 --- a/examples/simulation-tensorflow/pyproject.toml +++ b/examples/simulation-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "simulation-tensorflow" version = "0.1.0" description = "Federated Learning Simulation with Flower and Tensorflow" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"} -tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"} +tensorflow = { version = "^2.9.1, !=2.11.1", markers = "platform_machine == 'x86_64'" } +tensorflow-macos = { version = "^2.9.1, !=2.11.1", markers = "sys_platform == 'darwin' and platform_machine == 'arm64'" } diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 575b437018f3..9acfba99237c 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -189,7 +189,7 @@ " )\n", "\n", " # Create and return client\n", - " return FlowerClient(trainset, valset)\n", + " return FlowerClient(trainset, valset).to_client()\n", "\n", " return client_fn\n", "\n", @@ -232,7 +232,7 @@ "\n", "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." 
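As the notebook cell above suggests, partitioning is nearly a one-liner with Flower Datasets. A minimal sketch of the calls these examples rely on (`FederatedDataset`, `load_partition` and `load_full` all appear in the surrounding diffs; the partition count of 100 matches `NUM_CLIENTS`):

```python
from flwr_datasets import FederatedDataset

# Split MNIST's "train" split into 100 partitions, one per client
fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

partition = fds.load_partition(0, "train")  # local data for client 0
testset = fds.load_full("test")             # whole test set, for centralised evaluation
```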
]
 },
 {
@@ -323,11 +323,11 @@
 "\n",
 "Get all resources you need!\n",
 "\n",
- "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n",
- "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n",
+ "* **[DOCS]** Our complete documentation: https://flower.ai/docs/\n",
+ "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n",
 "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n",
 "\n",
- "Don't forget to join our Slack channel: https://flower.dev/join-slack/"
+ "Don't forget to join our Slack channel: https://flower.ai/join-slack/"
 ]
 }
],
diff --git a/examples/simulation-tensorflow/sim.py b/examples/simulation-tensorflow/sim.py
index 490e25fe8c8d..2a19e131fe79 100644
--- a/examples/simulation-tensorflow/sim.py
+++ b/examples/simulation-tensorflow/sim.py
@@ -1,5 +1,4 @@
 import os
-import math
 import argparse

 from typing import Dict, List, Tuple
@@ -29,9 +28,9 @@
 default=0.0,
 help="Ratio of GPU memory to assign to a virtual client",
)
-parser.add_argument("--num_rounds", type=int, default=10, help="Number of FL rounds.")

 NUM_CLIENTS = 100
+NUM_ROUNDS = 10
 VERBOSE = 0
@@ -94,7 +93,7 @@ def client_fn(cid: str) -> fl.client.Client:
 )

 # Create and return client
- return FlowerClient(trainset, valset)
+ return FlowerClient(trainset, valset).to_client()

 return client_fn

@@ -129,30 +128,39 @@ def evaluate(
 return evaluate


+# Download MNIST dataset and partition it
+mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
+# Get the whole test set for centralised evaluation
+centralized_testset = mnist_fds.load_full("test").to_tf_dataset(
+ columns="image", label_cols="label", batch_size=64
+)
+
+# Create FedAvg strategy
+strategy = fl.server.strategy.FedAvg(
+ fraction_fit=0.1, # Sample 10% of available clients for training
+ fraction_evaluate=0.05, # Sample 5% of available clients for evaluation
+ min_fit_clients=10, # Never sample less than 10 clients for training
+ evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics
+ evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function
+)
+
+
+# ClientApp for Flower-Next
+client = fl.client.ClientApp(
+ client_fn=get_client_fn(mnist_fds),
+)
+
+# ServerApp for Flower-Next
+server = fl.server.ServerApp(
+ config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
+ strategy=strategy,
+)
+
+
 def main() -> None:
 # Parse input arguments
 args = parser.parse_args()

- # Download MNIST dataset and partition it
- mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
- # Get the whole test set for centralised evaluation
- centralized_testset = mnist_fds.load_full("test").to_tf_dataset(
- columns="image", label_cols="label", batch_size=64
- )
-
- # Create FedAvg strategy
- strategy = fl.server.strategy.FedAvg(
- fraction_fit=0.1, # Sample 10% of available clients for training
- fraction_evaluate=0.05, # Sample 5% of available clients for evaluation
- min_fit_clients=10, # Never sample less than 10 clients for training
- min_evaluate_clients=5, # Never sample less than 5 clients for evaluation
- min_available_clients=int(
- NUM_CLIENTS * 0.75
- ), # Wait until at least 75 clients are available
- evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics
- evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function
- )
-
 # With a dictionary, you tell Flower's VirtualClientEngine that each
 # client needs exclusive access to these many resources in order to run
client_resources = { @@ -164,7 +172,7 @@ def main() -> None: fl.simulation.start_simulation( client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, - config=fl.server.ServerConfig(num_rounds=args.num_rounds), + config=fl.server.ServerConfig(NUM_ROUNDS), strategy=strategy, client_resources=client_resources, actor_kwargs={ diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index ee3cdfc9768e..12b1a5e3bc1a 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,7 +1,7 @@ # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 # or any integer in {0-9} +python3 client.py --partition-id 0 # or any integer in {0-9} ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 # or any integer in {0-9} +python3 client.py --partition-id 1 # or any integer in {0-9} ``` Alternatively, you can run all of it in one shell as follows: diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index a5fcaba87409..1e9349df1acc 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS}) @@ -62,4 +62,6 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} # Start Flower client - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient()) + fl.client.start_client( + server_address="0.0.0.0:8080", client=MnistClient().to_client() + ) diff --git a/examples/sklearn-logreg-mnist/pyproject.toml b/examples/sklearn-logreg-mnist/pyproject.toml index 8ea49fe187a2..58cc5ca4a02e 100644 --- a/examples/sklearn-logreg-mnist/pyproject.toml +++ b/examples/sklearn-logreg-mnist/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/sklearn-logreg-mnist/run.sh b/examples/sklearn-logreg-mnist/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/sklearn-logreg-mnist/run.sh +++ b/examples/sklearn-logreg-mnist/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py 
--partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/vertical-fl/pyproject.toml b/examples/vertical-fl/pyproject.toml index 14771c70062f..19dcd0e7a842 100644 --- a/examples/vertical-fl/pyproject.toml +++ b/examples/vertical-fl/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "vertical-fl" version = "0.1.0" description = "PyTorch Vertical FL with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/vit-finetune/README.md b/examples/vit-finetune/README.md new file mode 100644 index 000000000000..ac1652acf02d --- /dev/null +++ b/examples/vit-finetune/README.md @@ -0,0 +1,94 @@ +# Federated finetuning of a ViT + +This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) datasset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images). + +## Running the example + +If you haven't cloned the Flower repository already you might want to clone code example and discard the rest. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/vit-finetune . && rm -rf flower && cd vit-finetune +``` + +This will create a new directory called `vit-finetune` containing the following files: + +``` +-- README.md <- Your're reading this right now +-- main.py <- Main file that launches the simulation +-- client.py <- Contains Flower client code and ClientApp +-- server.py <- Contains Flower server code and ServerApp +-- model.py <- Defines model and train/eval functions +-- dataset.py <- Downloads, partitions and processes dataset +-- pyproject.toml <- Example dependencies, installable using Poetry +-- requirements.txt <- Example dependencies, installable using pip +``` + +### Installing Dependencies + +Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +#### pip + +With an activated environemnt, install the dependencies for this example: + +```shell +pip install -r requirements.txt +``` + +### Run with `start_simulation()` + +Running the example is quite straightforward. You can control the number of rounds `--num-rounds` (which defaults to 20). 
+
+```bash
+python main.py
+```
+
+![](_static/central_evaluation.png)
+
+Running the example as-is on an RTX 3090Ti should take ~15s/round running 5 clients in parallel (plus the _global model_ during centralized evaluation stages) on a single GPU. Note that more clients could fit in VRAM, but since the GPU utilization is high (99%-100%) we are probably better off not doing that (at least in this case).
+
+You can adjust the `client_resources` passed to `start_simulation()` so that more or fewer clients run at the same time on the GPU. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.
+
+```bash
++---------------------------------------------------------------------------------------+
+| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
+|-----------------------------------------+----------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+======================+======================|
+| 0 NVIDIA GeForce RTX 3090 Ti Off | 00000000:0B:00.0 Off | Off |
+| 44% 74C P2 441W / 450W | 7266MiB / 24564MiB | 100% Default |
+| | | N/A |
++-----------------------------------------+----------------------+----------------------+
+
++---------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=======================================================================================|
+| 0 N/A N/A 173812 C python 1966MiB |
+| 0 N/A N/A 174510 C ray::ClientAppActor.run 1056MiB |
+| 0 N/A N/A 174512 C ray::ClientAppActor.run 1056MiB |
+| 0 N/A N/A 174513 C ray::ClientAppActor.run 1056MiB |
+| 0 N/A N/A 174514 C ray::ClientAppActor.run 1056MiB |
+| 0 N/A N/A 174516 C ray::ClientAppActor.run 1056MiB |
++---------------------------------------------------------------------------------------+
+```
+
+### Run with Flower Next (preview)
+
+```bash
+flower-simulation \
+ --client-app=client:app \
+ --server-app=server:app \
+ --num-supernodes=20 \
+ --backend-config='{"client_resources": {"num_cpus":4, "num_gpus":0.25}}'
+```
diff --git a/examples/vit-finetune/_static/central_evaluation.png b/examples/vit-finetune/_static/central_evaluation.png
new file mode 100644
index 000000000000..d0d53ba353a1
Binary files /dev/null and b/examples/vit-finetune/_static/central_evaluation.png differ
diff --git a/examples/vit-finetune/client.py b/examples/vit-finetune/client.py
new file mode 100644
index 000000000000..68d98926feeb
--- /dev/null
+++ b/examples/vit-finetune/client.py
@@ -0,0 +1,82 @@
+import torch
+from torch.utils.data import DataLoader
+
+import flwr
+from flwr.client import NumPyClient
+from dataset import apply_transforms, get_dataset_with_partitions
+from model import get_model, set_parameters, train
+
+
+class FedViTClient(NumPyClient):
+
+ def __init__(self, trainset):
+
+ self.trainset = trainset
+ self.model = get_model()
+
+ # Determine device
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ self.model.to(self.device) # send model to device
+
+ def set_for_finetuning(self):
+ """Freeze all parameters except those in the final head.
+
+ Only the output MLP will be updated by the client; it is therefore the only
+ part of the model that will be federated (and hence communicated back to the
+ server for aggregation).
+ """
+
+ # Disable gradients for everything
+ self.model.requires_grad_(False)
+ # Now enable just for output head
+ self.model.heads.requires_grad_(True)
+
+ def get_parameters(self, config):
+ """Get locally updated parameters."""
+ finetune_layers = self.model.heads
+ return [val.cpu().numpy() for _, val in finetune_layers.state_dict().items()]
+
+ def fit(self, parameters, config):
+ set_parameters(self.model, parameters)
+
+ # Get some info from the config
+ # Get batch size and LR set from server
+ batch_size = config["batch_size"]
+ lr = config["lr"]
+
+ trainloader = DataLoader(
+ self.trainset, batch_size=batch_size, num_workers=2, shuffle=True
+ )
+
+ # Set optimizer
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
+ # Train locally
+ avg_train_loss = train(
+ self.model, trainloader, optimizer, epochs=1, device=self.device
+ )
+ # Return locally-finetuned part of the model
+ return (
+ self.get_parameters(config={}),
+ len(trainloader.dataset),
+ {"train_loss": avg_train_loss},
+ )
+
+
+# Download and partition the dataset
+federated_ox_flowers, _ = get_dataset_with_partitions(num_partitions=20)
+
+
+def client_fn(cid: str):
+ """Return a FedViTClient that trains with the cid-th data partition."""
+
+ trainset_for_this_client = federated_ox_flowers.load_partition(int(cid), "train")
+
+ trainset = trainset_for_this_client.with_transform(apply_transforms)
+
+ return FedViTClient(trainset).to_client()
+
+
+# To be used with Flower Next
+app = flwr.client.ClientApp(
+ client_fn=client_fn,
+)
diff --git a/examples/vit-finetune/dataset.py b/examples/vit-finetune/dataset.py
new file mode 100644
index 000000000000..c11eb7c19712
--- /dev/null
+++ b/examples/vit-finetune/dataset.py
@@ -0,0 +1,52 @@
+from torchvision.transforms import (
+ Compose,
+ Normalize,
+ ToTensor,
+ RandomResizedCrop,
+ Resize,
+ CenterCrop,
+)
+
+from flwr_datasets import FederatedDataset
+
+
+def get_dataset_with_partitions(num_partitions: int):
+ """Get the Oxford Flowers dataset and partition it.
+
+ Return the partitioned dataset as well as the whole test set.
+ """ + + # Get Oxford Flowers-102 and divide it into 20 IID partitions + ox_flowers_fds = FederatedDataset( + dataset="nelorth/oxford-flowers", partitioners={"train": num_partitions} + ) + + centralized_testset = ox_flowers_fds.load_full("test") + return ox_flowers_fds, centralized_testset + + +def apply_eval_transforms(batch): + """Apply a very standard set of image transforms.""" + transforms = Compose( + [ + Resize((256, 256)), + CenterCrop((224, 224)), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch + + +def apply_transforms(batch): + """Apply a very standard set of image transforms.""" + transforms = Compose( + [ + RandomResizedCrop((224, 224)), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch diff --git a/examples/vit-finetune/main.py b/examples/vit-finetune/main.py new file mode 100644 index 000000000000..1257246304a1 --- /dev/null +++ b/examples/vit-finetune/main.py @@ -0,0 +1,58 @@ +import argparse + +import flwr as fl +import matplotlib.pyplot as plt + +from server import strategy +from client import client_fn + +parser = argparse.ArgumentParser( + description="Finetuning of a ViT with Flower Simulation." +) + +parser.add_argument( + "--num-rounds", + type=int, + default=20, + help="Number of rounds.", +) + + +def main(): + + args = parser.parse_args() + + # To control the degree of parallelism + # With default settings in this example, + # each client should take just ~1GB of VRAM. + client_resources = { + "num_cpus": 4, + "num_gpus": 0.2, + } + + # Launch simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=20, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + ) + + print(history) + + # Basic plotting + global_accuracy_centralised = history.metrics_centralized["accuracy"] + round = [int(data[0]) for data in global_accuracy_centralised] + acc = [100.0 * data[1] for data in global_accuracy_centralised] + plt.plot(round, acc) + plt.xticks(round) + plt.grid() + plt.ylabel("Accuracy (%)") + plt.xlabel("Round") + plt.title("Federated finetuning of ViT for Flowers-102") + plt.savefig("central_evaluation.png") + + +if __name__ == "__main__": + main() diff --git a/examples/vit-finetune/model.py b/examples/vit-finetune/model.py new file mode 100644 index 000000000000..ca7dc1cd9864 --- /dev/null +++ b/examples/vit-finetune/model.py @@ -0,0 +1,71 @@ +from collections import OrderedDict + +import torch +from torchvision.models import vit_b_16, ViT_B_16_Weights + + +def get_model(): + """Return a pretrained ViT with all layers frozen except output head.""" + + # Instantiate a pre-trained ViT-B on ImageNet + model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1) + + # We're going to federated the finetuning of this model + # using the Oxford Flowers-102 dataset. One easy way to achieve + # this is by re-initializing the output block of the ViT so it + # outputs 102 clases instead of the default 1k + in_features = model.heads[-1].in_features + model.heads[-1] = torch.nn.Linear(in_features, 102) + + # Disable gradients for everything + model.requires_grad_(False) + # Now enable just for output head + model.heads.requires_grad_(True) + + return model + + +def set_parameters(model, parameters): + """Apply the parameters to the model. 
+ + Recall this example only federates the head of the ViT so that's the only part of + the model we need to load. + """ + finetune_layers = model.heads + params_dict = zip(finetune_layers.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + finetune_layers.load_state_dict(state_dict, strict=True) + + +def train(net, trainloader, optimizer, epochs, device): + """Train the model on the training set.""" + criterion = torch.nn.CrossEntropyLoss() + net.train() + avg_loss = 0 + # A very standard training loop for image classification + for _ in range(epochs): + for batch in trainloader: + images, labels = batch["image"].to(device), batch["label"].to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + avg_loss += loss.item() / labels.shape[0] + loss.backward() + optimizer.step() + + return avg_loss / len(trainloader) + + +def test(net, testloader, device: str): + """Validate the network on the entire test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + net.eval() + with torch.no_grad(): + for data in testloader: + images, labels = data["image"].to(device), data["label"].to(device) + outputs = net(images) + loss += criterion(outputs, labels).item() + _, predicted = torch.max(outputs.data, 1) + correct += (predicted == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy diff --git a/examples/vit-finetune/pyproject.toml b/examples/vit-finetune/pyproject.toml new file mode 100644 index 000000000000..d014d6b6fb2a --- /dev/null +++ b/examples/vit-finetune/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "vit-finetune" +version = "0.1.0" +description = "FL finetuning of a Vision Transformer with Flower." 
+authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +torch = "2.2.1" +torchvision = "0.17.1" +matplotlib = "3.8.3" diff --git a/examples/vit-finetune/requirements.txt b/examples/vit-finetune/requirements.txt new file mode 100644 index 000000000000..3692be0d6c2c --- /dev/null +++ b/examples/vit-finetune/requirements.txt @@ -0,0 +1,5 @@ +flwr[simulation]>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +matplotlib==3.8.3 +torch==2.2.1 +torchvision==0.17.1 \ No newline at end of file diff --git a/examples/vit-finetune/server.py b/examples/vit-finetune/server.py new file mode 100644 index 000000000000..698bcd45cece --- /dev/null +++ b/examples/vit-finetune/server.py @@ -0,0 +1,61 @@ +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +import flwr as fl + +from dataset import apply_eval_transforms, get_dataset_with_partitions +from model import get_model, set_parameters, test + + +def fit_config(server_round: int): + """Return a configuration with static batch size and (local) epochs.""" + config = { + "lr": 0.01, # Learning rate used by clients + "batch_size": 32, # Batch size to use by clients during fit() + } + return config + + +def get_evaluate_fn( + centralized_testset: Dataset, +): + """Return an evaluation function for centralized evaluation.""" + + def evaluate(server_round, parameters, config): + """Use the entire Oxford Flowers-102 test set for evaluation.""" + + # Determine device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model = get_model() + set_parameters(model, parameters) + model.to(device) + + # Apply transform to dataset + testset = centralized_testset.with_transform(apply_eval_transforms) + + testloader = DataLoader(testset, batch_size=128) + # Run evaluation + loss, accuracy = test(model, testloader, device=device) + + return loss, {"accuracy": accuracy} + + return evaluate + + +# Downloads and partition dataset +_, centralized_testset = get_dataset_with_partitions(num_partitions=20) + +# Configure the strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=0.5, # Sample 50% of available clients for training each round + fraction_evaluate=0.0, # No federated evaluation + on_fit_config_fn=fit_config, + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function +) + +# To be used with Flower Next +app = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index e89a09519fed..ddebe51247b2 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -110,7 +110,7 @@ An overview of the FL pipeline built with Flower for this example is illustrated 3. Once on-site training is completed, each client sends back the (now updated) classification head to the Flower server. 4. The Flower server aggregates (via FedAvg) the classification heads in order to obtain a new _global_ classification head. This head will be shared with clients in the next round. -Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. 
The former, managed by the [`VirtualClientEngine`](https://flower.dev/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. +Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. The former, managed by the [`VirtualClientEngine`](https://flower.ai/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. ### Preparing the dataset @@ -147,7 +147,7 @@ INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {' With just 5 FL rounds, the global model should be reaching ~95% validation accuracy. A test accuracy of 97% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~20-30s depending on the amount of data the clients selected in a round have. -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. ### Federated Finetuning (non-simulated) diff --git a/examples/whisper-federated-finetuning/client.py b/examples/whisper-federated-finetuning/client.py index 2bfeadfbdae6..d3bb217933f8 100644 --- a/examples/whisper-federated-finetuning/client.py +++ b/examples/whisper-federated-finetuning/client.py @@ -146,13 +146,13 @@ def client_fn(cid: str): return WhisperFlowerClient( full_train_dataset, num_classes, disable_tqdm, compile - ) + ).to_client() return client_fn -def run_client(): - """Run clinet.""" +def main(): + """Run client.""" # Parse input arguments args = parser.parse_args() @@ -174,10 +174,11 @@ def run_client(): client_data_path=CLIENT_DATA, ) - fl.client.start_numpy_client( - server_address=f"{args.server_address}:8080", client=client_fn(args.cid) + fl.client.start_client( + server_address=f"{args.server_address}:8080", + client=client_fn(args.cid), ) if __name__ == "__main__": - run_client() + main() diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml index dd5578b8b3d0..27a89578c5a0 100644 --- a/examples/whisper-federated-finetuning/pyproject.toml +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "whisper-flower" version = "0.1.0" description = "On-device Federated Downstreaming for Speech Classification" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -16,4 +16,4 @@ tokenizers = "0.13.3" datasets = "2.14.6" soundfile = "0.12.1" librosa = "0.10.1" -# this example was tested with pytorch 2.1.0 \ No newline at end of file +# this example was tested with pytorch 2.1.0 diff --git 
a/examples/whisper-federated-finetuning/requirements.txt b/examples/whisper-federated-finetuning/requirements.txt
index eb4a5d7eb47b..f16b3d6993ce 100644
--- a/examples/whisper-federated-finetuning/requirements.txt
+++ b/examples/whisper-federated-finetuning/requirements.txt
@@ -3,5 +3,4 @@ tokenizers==0.13.3
 datasets==2.14.6
 soundfile==0.12.1
 librosa==0.10.1
-flwr==1.5.0
-ray==2.6.3
\ No newline at end of file
+flwr[simulation]>=1.0, <2.0
\ No newline at end of file
diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index 11c4c3f9a08b..dc6d7e3872d6 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -1,7 +1,7 @@
 # Flower Example using XGBoost (Comprehensive)

 This example demonstrates a comprehensive federated learning setup using Flower with XGBoost.
-We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task.
+We use the [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to retrieve, partition and preprocess the data for each Flower client.
 It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) example in the following ways:

 - Arguments parsers of server and clients for hyperparameters selection.
@@ -10,6 +10,32 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) example in the following ways:
 - Customised partitioner type (uniform, linear, square, exponential).
 - Centralised/distributed evaluation.
 - Bagging/cyclic training methods.
+- You can run it with Flower Simulation
+
+## Training Strategies
+
+This example provides two training strategies, **bagging aggregation** and **cyclic training**.
+
+### Bagging Aggregation
+
+Bagging (bootstrap) aggregation is an ensemble meta-algorithm in machine learning,
+used for enhancing the stability and accuracy of machine learning algorithms.
+Here, we leverage this algorithm for XGBoost trees.
+
+Specifically, each client is treated as a bootstrap by random subsampling (data partitioning in FL).
+At each FL round, all clients boost a number of trees (in this example, 1 tree) based on the local bootstrap samples.
+Then, the clients' trees are aggregated on the server and concatenated to the global model from the previous round.
+The aggregated tree ensemble is regarded as the new global model.
+
+To illustrate, consider a scenario with M clients:
+after R FL rounds, the bagging model consists of (M * R) trees.
+
+### Cyclic Training
+
+Cyclic XGBoost training performs FL in a client-by-client fashion.
+Instead of aggregating multiple clients,
+there is only one single client participating in the training per round in the cyclic training scenario.
+The trained local XGBoost trees will be passed to the next client as an initialised model for the next round's boosting.
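The client-side difference between the two strategies is easiest to see in code. The sketch below mirrors the `_local_boost` method added in `client_utils.py` further down this diff; it assumes an `xgb.Booster` already loaded with the global model and a `DMatrix` of local training data:

```python
import xgboost as xgb

def local_boost(bst: xgb.Booster, train_dmatrix: xgb.DMatrix,
                num_local_round: int, train_method: str) -> xgb.Booster:
    # Continue boosting the global model on this client's local data
    for _ in range(num_local_round):
        bst.update(train_dmatrix, bst.num_boosted_rounds())

    n = bst.num_boosted_rounds()
    if train_method == "bagging":
        # Bagging: send back only the newly boosted trees (Booster slicing)
        return bst[n - num_local_round : n]
    # Cyclic: hand the entire model to the next client
    return bst
```

With M = 5 clients each contributing one tree per round, the bagging model therefore holds roughly 5 * R trees after R rounds, while the cyclic model holds roughly R.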
## Project Setup
@@ -26,7 +52,10 @@ This will create a new directory called `xgboost-comprehensive` containing the f
 -- server.py <- Defines the server-side logic
 -- client.py <- Defines the client-side logic
 -- dataset.py <- Defines the functions of data loading and partitioning
--- utils.py <- Defines the arguments parser for clients and server
+-- utils.py <- Defines the arguments parser and hyper-parameters
+-- client_utils.py <- Defines the client utility functions
+-- server_utils.py <- Defines the server utility functions
+-- sim.py <- Example of using Flower simulation
 -- run_bagging.sh <- Commands to run bagging experiments
 -- run_cyclic.sh <- Commands to run cyclic experiments
 -- pyproject.toml <- Example dependencies (if you use Poetry)
@@ -47,7 +76,7 @@
 poetry shell

 Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

 ```shell
-poetry run python3 -c "import flwr"
+poetry run python -c "import flwr"
 ```

 If you don't see any errors you're good to go!
@@ -62,44 +91,76 @@
 pip install -r requirements.txt

 ## Run Federated Learning with XGBoost and Flower

+You can run this example in two ways: either by manually launching the server, and then several clients that connect to it; or by launching a Flower simulation. Both run the same workload, yielding identical results. The former is ideal for deployments on different machines, while the latter makes it easy to simulate large client cohorts in a resource-aware manner. You can read more about how Flower Simulation works in the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html). The commands shown below assume you have activated your environment (if you decide to use Poetry, you can activate it via `poetry shell`).
+
+### Independent Client/Server Setup
+
 We have two scripts to run bagging and cyclic (client-by-client) experiments.
 The included `run_bagging.sh` or `run_cyclic.sh` will start the Flower server (using `server.py`),
 sleep for 15 seconds to ensure that the server is up,
 and then start 5 Flower clients (using `client.py`) with a small subset of the data from exponential partition distribution.
+
 You can simply start everything in a terminal as follows:

 ```shell
-poetry run ./run_bagging.sh
+./run_bagging.sh
 ```

 Or

 ```shell
-poetry run ./run_cyclic.sh
+./run_cyclic.sh
+```
+
+The script starts processes in the background so that you don't have to open six terminal windows.
+
+You can also run the example without the scripts. First, launch the server:
+
+```bash
+python server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N
+```
+
+Then run at least two clients (each on a new terminal or computer in your network) passing different `PARTITION_ID` and all using the same `N` (denoting the total number of clients or data partitions):
+
+```bash
+python client.py --train-method=bagging/cyclic --partition-id=PARTITION_ID --num-partitions=N
+```
+
+### Flower Simulation Setup
+
+We also provide example code (`sim.py`) that uses the simulation capabilities of Flower to simulate federated XGBoost training on either a single machine or a cluster of machines. With default arguments, each client will use 2 CPUs.
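To give a feel for what `sim.py` does, here is a hypothetical simplification (not the file's exact contents, which also thread through `params` and `num_local_round` explicitly): it pre-loads one `DMatrix` per partition and hands each virtual client its own data through a closure, reusing `XgbClient` from `client_utils.py`:

```python
from client_utils import XgbClient
from utils import BST_PARAMS, NUM_LOCAL_ROUND

def get_client_fn(train_data_list, valid_data_list, train_method):
    # Assumption: each list entry is a (DMatrix, num_examples) pair for one partition
    def client_fn(cid: str):
        train_dmatrix, num_train = train_data_list[int(cid)]
        valid_dmatrix, num_val = valid_data_list[int(cid)]
        return XgbClient(
            train_dmatrix, valid_dmatrix, num_train, num_val,
            NUM_LOCAL_ROUND, BST_PARAMS, train_method,
        )
    return client_fn
```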
+
+To run bagging aggregation with 5 clients for 30 rounds evaluated on a centralised test set:
+
+```shell
+python sim.py --train-method=bagging --pool-size=5 --num-clients-per-round=5 --num-rounds=30 --centralised-eval
 ```

-The script starts processes in the background so that you don't have to open eleven terminal windows.
-If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes,
-which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
-This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`).
-If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`)
-to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).
+To run cyclic training with 5 clients for 30 rounds evaluated on a centralised test set:

-You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N`
-and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
-but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
+```shell
+python sim.py --train-method=cyclic --pool-size=5 --num-rounds=30 --centralised-eval-client
+```

 In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
-Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
-and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
+Check the [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.

 ### Expected Experimental Results

 #### Bagging aggregation experiment

-![](_static/xgboost_flower_auc.png)
+![](_static/xgboost_flower_auc_bagging.png)

-The figure above shows the centralised tested AUC performance over FL rounds on 4 experimental settings.
+The figure above shows the centralised test AUC performance over FL rounds with the bagging aggregation strategy on 4 experimental settings.
 One can see that all settings obtain stable performance boost over FL rounds (especially noticeable at the start of training).
-As expected, uniform client distribution shows higher AUC values (beyond 83% at the end) than square/exponential setup.
-Feel free to explore more interesting experiments by yourself!
+As expected, uniform client distribution shows higher AUC values than the square/exponential setups.
+
+#### Cyclic training experiment
+
+![](_static/xgboost_flower_auc_cyclic.png)
+
+This figure shows the cyclic training results on the centralised test set.
+The models with cyclic training require more rounds to converge
+because only a single client participates in the training per round.
+
+Feel free to explore more interesting experiments by yourself!
diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png deleted file mode 100644 index e6a4bfb83250..000000000000 Binary files a/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png and /dev/null differ diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png new file mode 100644 index 000000000000..e192df214471 Binary files /dev/null and b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png differ diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png new file mode 100644 index 000000000000..731d0fc3fbbc Binary files /dev/null and b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png differ diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index ff7a4adf7977..66daed449fd5 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -1,21 +1,9 @@ import warnings from logging import INFO -import xgboost as xgb import flwr as fl from flwr_datasets import FederatedDataset from flwr.common.logger import log -from flwr.common import ( - Code, - EvaluateIns, - EvaluateRes, - FitIns, - FitRes, - GetParametersIns, - GetParametersRes, - Parameters, - Status, -) from dataset import ( instantiate_partitioner, @@ -23,7 +11,8 @@ transform_dataset_to_dmatrix, resplit, ) -from utils import client_args_parser, BST_PARAMS +from utils import client_args_parser, BST_PARAMS, NUM_LOCAL_ROUND +from client_utils import XgbClient warnings.filterwarnings("ignore", category=UserWarning) @@ -32,15 +21,13 @@ # Parse arguments for experimental settings args = client_args_parser() -# Load (HIGGS) dataset and conduct partitioning -num_partitions = args.num_partitions - -# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"] -partitioner_type = args.partitioner_type +# Train method (bagging or cyclic) +train_method = args.train_method -# Instantiate partitioner +# Load (HIGGS) dataset and conduct partitioning +# Instantiate partitioner from ["uniform", "linear", "square", "exponential"] partitioner = instantiate_partitioner( - partitioner_type=partitioner_type, num_partitions=num_partitions + partitioner_type=args.partitioner_type, num_partitions=args.num_partitions ) fds = FederatedDataset( dataset="jxie/higgs", @@ -48,10 +35,9 @@ resplitter=resplit, ) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -node_id = args.node_id -partition = fds.load_partition(node_id=node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") if args.centralised_eval: @@ -63,10 +49,8 @@ num_val = valid_data.shape[0] else: # Train/test splitting - SEED = args.seed - test_fraction = args.test_fraction train_data, valid_data, num_train, num_val = train_test_split( - partition, test_fraction=test_fraction, seed=SEED + partition, test_fraction=args.test_fraction, seed=args.seed ) # Reformat data to DMatrix for xgboost @@ -74,101 +58,25 @@ train_dmatrix = transform_dataset_to_dmatrix(train_data) valid_dmatrix = transform_dataset_to_dmatrix(valid_data) - # Hyper-parameters for xgboost training -num_local_round = 1 +num_local_round = NUM_LOCAL_ROUND params = BST_PARAMS - -# 
Define Flower client -class XgbClient(fl.client.Client): - def __init__(self): - self.bst = None - self.config = None - - def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: - _ = (self, ins) - return GetParametersRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[]), - ) - - def _local_boost(self): - # Update trees based on local training data. - for i in range(num_local_round): - self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) - - # Bagging: extract the last N=num_local_round trees for sever aggregation - # Cyclic: return the entire model - bst = ( - self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() - ] - if args.train_method == "bagging" - else self.bst - ) - - return bst - - def fit(self, ins: FitIns) -> FitRes: - if not self.bst: - # First round local training - log(INFO, "Start training at round 1") - bst = xgb.train( - params, - train_dmatrix, - num_boost_round=num_local_round, - evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], - ) - self.config = bst.save_config() - self.bst = bst - else: - for item in ins.parameters.tensors: - global_model = bytearray(item) - - # Load global model into booster - self.bst.load_model(global_model) - self.bst.load_config(self.config) - - bst = self._local_boost() - - local_model = bst.save_raw("json") - local_model_bytes = bytes(local_model) - - return FitRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=num_train, - metrics={}, - ) - - def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - eval_results = self.bst.eval_set( - evals=[(valid_dmatrix, "valid")], - iteration=self.bst.num_boosted_rounds() - 1, - ) - auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) - - global_round = ins.config["global_round"] - log(INFO, f"AUC = {auc} at round {global_round}") - - return EvaluateRes( - status=Status( - code=Code.OK, - message="OK", - ), - loss=0.0, - num_examples=num_val, - metrics={"AUC": auc}, - ) - +# Setup learning rate +if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.num_partitions + params.update({"eta": new_lr}) # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", + client=XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ), +) diff --git a/examples/xgboost-comprehensive/client_utils.py b/examples/xgboost-comprehensive/client_utils.py new file mode 100644 index 000000000000..d2e07677ef97 --- /dev/null +++ b/examples/xgboost-comprehensive/client_utils.py @@ -0,0 +1,126 @@ +from logging import INFO +import xgboost as xgb + +import flwr as fl +from flwr.common.logger import log +from flwr.common import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + Parameters, + Status, +) + + +class XgbClient(fl.client.Client): + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params + self.train_method = train_method + + def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: + _ 
= (self, ins)
+ return GetParametersRes(
+ status=Status(
+ code=Code.OK,
+ message="OK",
+ ),
+ parameters=Parameters(tensor_type="", tensors=[]),
+ )
+
+ def _local_boost(self, bst_input):
+ # Update trees based on local training data.
+ for i in range(self.num_local_round):
+ bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds())
+
+ # Bagging: extract the last N=num_local_round trees for server aggregation
+ # Cyclic: return the entire model
+ bst = (
+ bst_input[
+ bst_input.num_boosted_rounds()
+ - self.num_local_round : bst_input.num_boosted_rounds()
+ ]
+ if self.train_method == "bagging"
+ else bst_input
+ )
+
+ return bst
+
+ def fit(self, ins: FitIns) -> FitRes:
+ global_round = int(ins.config["global_round"])
+ if global_round == 1:
+ # First round local training
+ bst = xgb.train(
+ self.params,
+ self.train_dmatrix,
+ num_boost_round=self.num_local_round,
+ evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")],
+ )
+ else:
+ bst = xgb.Booster(params=self.params)
+ for item in ins.parameters.tensors:
+ global_model = bytearray(item)
+
+ # Load global model into booster
+ bst.load_model(global_model)
+
+ # Local training
+ bst = self._local_boost(bst)
+
+ # Save model
+ local_model = bst.save_raw("json")
+ local_model_bytes = bytes(local_model)
+
+ return FitRes(
+ status=Status(
+ code=Code.OK,
+ message="OK",
+ ),
+ parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+ num_examples=self.num_train,
+ metrics={},
+ )
+
+ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+ # Load global model
+ bst = xgb.Booster(params=self.params)
+ for para in ins.parameters.tensors:
+ para_b = bytearray(para)
+ bst.load_model(para_b)
+
+ # Run evaluation
+ eval_results = bst.eval_set(
+ evals=[(self.valid_dmatrix, "valid")],
+ iteration=bst.num_boosted_rounds() - 1,
+ )
+ auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+ global_round = ins.config["global_round"]
+ log(INFO, f"AUC = {auc} at round {global_round}")
+
+ return EvaluateRes(
+ status=Status(
+ code=Code.OK,
+ message="OK",
+ ),
+ loss=0.0,
+ num_examples=self.num_val,
+ metrics={"AUC": auc},
+ )
diff --git a/examples/xgboost-comprehensive/dataset.py b/examples/xgboost-comprehensive/dataset.py
index bcf2e00b30af..94959925f833 100644
--- a/examples/xgboost-comprehensive/dataset.py
+++ b/examples/xgboost-comprehensive/dataset.py
@@ -39,12 +39,18 @@ def train_test_split(partition: Dataset, test_fraction: float, seed: int):

 def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
 """Transform dataset to DMatrix format for xgboost."""
- x = data["inputs"]
- y = data["label"]
+ x, y = separate_xy(data)
 new_data = xgb.DMatrix(x, label=y)
 return new_data

+def separate_xy(data: Union[Dataset, DatasetDict]):
+ """Return x (data) and y (labels)."""
+ x = data["inputs"]
+ y = data["label"]
+ return x, y
+
+
 def resplit(dataset: DatasetDict) -> DatasetDict:
 """Increase the quantity of centralised test samples from 500K to 1M."""
 return DatasetDict(
diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml
index bbfbb4134b8d..b257801cb420 100644
--- a/examples/xgboost-comprehensive/pyproject.toml
+++ b/examples/xgboost-comprehensive/pyproject.toml
@@ -6,10 +6,10 @@ build-backend = "poetry.core.masonry.api"
 name = "xgboost-comprehensive"
 version = "0.1.0"
 description = "Federated XGBoost with Flower (comprehensive)"
-authors = ["The Flower Authors "]
+authors = ["The Flower Authors "]
[tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = ">=1.0,<2.0" +flwr-nightly = { extras = ["simulation"], version = ">=1.7.0,<2.0" } flwr-datasets = ">=0.0.2,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-comprehensive/requirements.txt b/examples/xgboost-comprehensive/requirements.txt index c37ac2b6ad6d..b5b1d83bcdd1 100644 --- a/examples/xgboost-comprehensive/requirements.txt +++ b/examples/xgboost-comprehensive/requirements.txt @@ -1,3 +1,3 @@ -flwr-nightly>=1.0, <2.0 +flwr[simulation]>=1.7.0, <2.0 flwr-datasets>=0.0.2, <1.0.0 xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-comprehensive/run_bagging.sh b/examples/xgboost-comprehensive/run_bagging.sh index 7920f6bf5e55..a6300b781a06 100755 --- a/examples/xgboost-comprehensive/run_bagging.sh +++ b/examples/xgboost-comprehensive/run_bagging.sh @@ -3,12 +3,12 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ echo "Starting server" -python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval & -sleep 15 # Sleep for 15s to give the server enough time to start +python3 server.py --pool-size=5 --num-rounds=30 --num-clients-per-round=5 --centralised-eval & +sleep 30 # Sleep for 30s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --num-partitions=5 --partitioner-type=exponential & + python3 client.py --partition-id=$i --num-partitions=5 --partitioner-type=exponential & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh index 47e09fd8faef..258bdf2fe0d8 100755 --- a/examples/xgboost-comprehensive/run_cyclic.sh +++ b/examples/xgboost-comprehensive/run_cyclic.sh @@ -8,7 +8,7 @@ sleep 15 # Sleep for 15s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & + python3 client.py --partition-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 1cf4ba79fa50..2fecbcc65853 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -1,19 +1,19 @@ import warnings -from typing import Dict, List, Optional from logging import INFO -import xgboost as xgb import flwr as fl from flwr.common.logger import log -from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset -from flwr.server.strategy import FedXgbBagging -from flwr.server.strategy import FedXgbCyclic -from flwr.server.client_proxy import ClientProxy -from flwr.server.criterion import Criterion -from flwr.server.client_manager import SimpleClientManager - -from utils import server_args_parser, BST_PARAMS +from flwr.server.strategy import FedXgbBagging, FedXgbCyclic + +from utils import server_args_parser +from server_utils import ( + eval_config, + fit_config, + evaluate_metrics_aggregation, + get_evaluate_fn, + CyclicClientManager, +) from dataset import resplit, transform_dataset_to_dmatrix @@ -34,97 +34,11 @@ fds = FederatedDataset( dataset="jxie/higgs", partitioners={"train": 20}, resplitter=resplit ) + log(INFO, "Loading centralised test set...") test_set = fds.load_full("test") 
test_set.set_format("numpy") test_dmatrix = transform_dataset_to_dmatrix(test_set) -# Hyper-parameters used for initialisation -params = BST_PARAMS - - -def eval_config(rnd: int) -> Dict[str, str]: - """Return a configuration with global epochs.""" - config = { - "global_round": str(rnd), - } - return config - - -def evaluate_metrics_aggregation(eval_metrics): - """Return an aggregated metric (AUC) for evaluation.""" - total_num = sum([num for num, _ in eval_metrics]) - auc_aggregated = ( - sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num - ) - metrics_aggregated = {"AUC": auc_aggregated} - return metrics_aggregated - - -def get_evaluate_fn(test_data): - """Return a function for centralised evaluation.""" - - def evaluate_fn( - server_round: int, parameters: Parameters, config: Dict[str, Scalar] - ): - # If at the first round, skip the evaluation - if server_round == 0: - return 0, {} - else: - bst = xgb.Booster(params=params) - for para in parameters.tensors: - para_b = bytearray(para) - - # Load global model - bst.load_model(para_b) - # Run evaluation - eval_results = bst.eval_set( - evals=[(test_data, "valid")], - iteration=bst.num_boosted_rounds() - 1, - ) - auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) - log(INFO, f"AUC = {auc} at round {server_round}") - - return 0, {"AUC": auc} - - return evaluate_fn - - -class CyclicClientManager(SimpleClientManager): - """Provides a cyclic client selection rule.""" - - def sample( - self, - num_clients: int, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: - """Sample a number of Flower ClientProxy instances.""" - - # Block until at least num_clients are connected. - if min_num_clients is None: - min_num_clients = num_clients - self.wait_for(min_num_clients) - - # Sample clients which meet the criterion - available_cids = list(self.clients) - if criterion is not None: - available_cids = [ - cid for cid in available_cids if criterion.select(self.clients[cid]) - ] - - if num_clients > len(available_cids): - log( - INFO, - "Sampling failed: number of available clients" - " (%s) is less than number of requested clients (%s).", - len(available_cids), - num_clients, - ) - return [] - - # Return all available clients - return [self.clients[cid] for cid in available_cids] - # Define strategy if train_method == "bagging": @@ -137,9 +51,10 @@ def sample( min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0, fraction_evaluate=1.0 if not centralised_eval else 0.0, on_evaluate_config_fn=eval_config, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation - if not centralised_eval - else None, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not centralised_eval else None + ), ) else: # Cyclic training @@ -149,6 +64,7 @@ def sample( fraction_evaluate=1.0, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, ) # Start Flower server diff --git a/examples/xgboost-comprehensive/server_utils.py b/examples/xgboost-comprehensive/server_utils.py new file mode 100644 index 000000000000..35a31bd9adac --- /dev/null +++ b/examples/xgboost-comprehensive/server_utils.py @@ -0,0 +1,101 @@ +from typing import Dict, List, Optional +from logging import INFO +import xgboost as xgb +from flwr.common.logger import log +from flwr.common import Parameters, Scalar +from flwr.server.client_manager import SimpleClientManager +from 
flwr.server.client_proxy import ClientProxy +from flwr.server.criterion import Criterion +from utils import BST_PARAMS + + +def eval_config(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + +def fit_config(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + +def evaluate_metrics_aggregation(eval_metrics): + """Return an aggregated metric (AUC) for evaluation.""" + total_num = sum([num for num, _ in eval_metrics]) + auc_aggregated = ( + sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num + ) + metrics_aggregated = {"AUC": auc_aggregated} + return metrics_aggregated + + +def get_evaluate_fn(test_data): + """Return a function for centralised evaluation.""" + + def evaluate_fn( + server_round: int, parameters: Parameters, config: Dict[str, Scalar] + ): + # If at the first round, skip the evaluation + if server_round == 0: + return 0, {} + else: + bst = xgb.Booster(params=BST_PARAMS) + for para in parameters.tensors: + para_b = bytearray(para) + + # Load global model + bst.load_model(para_b) + # Run evaluation + eval_results = bst.eval_set( + evals=[(test_data, "valid")], + iteration=bst.num_boosted_rounds() - 1, + ) + auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + log(INFO, f"AUC = {auc} at round {server_round}") + + return 0, {"AUC": auc} + + return evaluate_fn + + +class CyclicClientManager(SimpleClientManager): + """Provides a cyclic client selection rule.""" + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + + # Block until at least num_clients are connected. + if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + # Return all available clients + return [self.clients[cid] for cid in available_cids] diff --git a/examples/xgboost-comprehensive/sim.py b/examples/xgboost-comprehensive/sim.py new file mode 100644 index 000000000000..b72b23931929 --- /dev/null +++ b/examples/xgboost-comprehensive/sim.py @@ -0,0 +1,188 @@ +import warnings +from logging import INFO +import xgboost as xgb +from tqdm import tqdm + +import flwr as fl +from flwr_datasets import FederatedDataset +from flwr.common.logger import log +from flwr.server.strategy import FedXgbBagging, FedXgbCyclic + +from dataset import ( + instantiate_partitioner, + train_test_split, + transform_dataset_to_dmatrix, + separate_xy, + resplit, +) +from utils import ( + sim_args_parser, + NUM_LOCAL_ROUND, + BST_PARAMS, +) +from server_utils import ( + eval_config, + fit_config, + evaluate_metrics_aggregation, + get_evaluate_fn, + CyclicClientManager, +) +from client_utils import XgbClient + + +warnings.filterwarnings("ignore", category=UserWarning) + + +def get_client_fn( + train_data_list, valid_data_list, train_method, params, num_local_round +): + """Return a function to construct a client. 
+ + The VirtualClientEngine will execute this function whenever a client is sampled by + the strategy to participate. + """ + + def client_fn(cid: str) -> fl.client.Client: + """Construct a FlowerClient with its own dataset partition.""" + x_train, y_train = train_data_list[int(cid)][0] + x_valid, y_valid = valid_data_list[int(cid)][0] + + # Reformat data to DMatrix + train_dmatrix = xgb.DMatrix(x_train, label=y_train) + valid_dmatrix = xgb.DMatrix(x_valid, label=y_valid) + + # Fetch the number of examples + num_train = train_data_list[int(cid)][1] + num_val = valid_data_list[int(cid)][1] + + # Create and return client + return XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ) + + return client_fn + + +def main(): + # Parse arguments for experimental settings + args = sim_args_parser() + + # Load (HIGGS) dataset and conduct partitioning + partitioner = instantiate_partitioner( + partitioner_type=args.partitioner_type, num_partitions=args.pool_size + ) + fds = FederatedDataset( + dataset="jxie/higgs", + partitioners={"train": partitioner}, + resplitter=resplit, + ) + + # Load centralised test set + if args.centralised_eval or args.centralised_eval_client: + log(INFO, "Loading centralised test set...") + test_data = fds.load_full("test") + test_data.set_format("numpy") + num_test = test_data.shape[0] + test_dmatrix = transform_dataset_to_dmatrix(test_data) + + # Load partitions and reformat data to DMatrix for xgboost + log(INFO, "Loading client local partitions...") + train_data_list = [] + valid_data_list = [] + + # Load and process all client partitions. This upfront cost is amortized soon + # after the simulation begins since clients won't need to preprocess their partition. + for partition_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with partition_id + partition = fds.load_partition(partition_id=partition_id, split="train") + partition.set_format("numpy") + + if args.centralised_eval_client: + # Use centralised test set for evaluation + train_data = partition + num_train = train_data.shape[0] + x_test, y_test = separate_xy(test_data) + valid_data_list.append(((x_test, y_test), num_test)) + else: + # Train/test splitting + train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=args.test_fraction, seed=args.seed + ) + x_valid, y_valid = separate_xy(valid_data) + valid_data_list.append(((x_valid, y_valid), num_val)) + + x_train, y_train = separate_xy(train_data) + train_data_list.append(((x_train, y_train), num_train)) + + # Define strategy + if args.train_method == "bagging": + # Bagging training + strategy = FedXgbBagging( + evaluate_function=( + get_evaluate_fn(test_dmatrix) if args.centralised_eval else None + ), + fraction_fit=(float(args.num_clients_per_round) / args.pool_size), + min_fit_clients=args.num_clients_per_round, + min_available_clients=args.pool_size, + min_evaluate_clients=( + args.num_evaluate_clients if not args.centralised_eval else 0 + ), + fraction_evaluate=1.0 if not args.centralised_eval else 0.0, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not args.centralised_eval else None + ), + ) + else: + # Cyclic training + strategy = FedXgbCyclic( + fraction_fit=1.0, + min_available_clients=args.pool_size, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + 
on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + ) + + # Resources to be assigned to each virtual client + # In this example we use CPU by default + client_resources = { + "num_cpus": args.num_cpus_per_client, + "num_gpus": 0.0, + } + + # Hyper-parameters for xgboost training + num_local_round = NUM_LOCAL_ROUND + params = BST_PARAMS + + # Setup learning rate + if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.pool_size + params.update({"eta": new_lr}) + + # Start simulation + fl.simulation.start_simulation( + client_fn=get_client_fn( + train_data_list, + valid_data_list, + args.train_method, + params, + num_local_round, + ), + num_clients=args.pool_size, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + client_manager=CyclicClientManager() if args.train_method == "cyclic" else None, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 8acdbbb88a7e..abc100da1ade 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -1,6 +1,8 @@ import argparse +# Hyper-parameters for xgboost training +NUM_LOCAL_ROUND = 1 BST_PARAMS = { "objective": "binary:logistic", "eta": 0.1, # Learning rate @@ -35,10 +37,10 @@ def client_args_parser(): help="Partitioner types.", ) parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) parser.add_argument( "--seed", default=42, type=int, help="Seed used for train/test splitting." @@ -52,7 +54,12 @@ def client_args_parser(): parser.add_argument( "--centralised-eval", action="store_true", - help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Scale the learning rate based on the number of clients (True).", ) args = parser.parse_args() @@ -96,3 +103,78 @@ def server_args_parser(): args = parser.parse_args() return args + + +def sim_args_parser(): + """Parse arguments to define experimental settings on server side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training method selected from bagging aggregation or cyclic training.", + ) + + # Server side + parser.add_argument( + "--pool-size", default=5, type=int, help="Number of total clients." + ) + parser.add_argument( + "--num-rounds", default=30, type=int, help="Number of FL rounds."
+ ) + parser.add_argument( + "--num-clients-per-round", + default=5, + type=int, + help="Number of clients participating in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=5, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + parser.add_argument( + "--num-cpus-per-client", + default=2, + type=int, + help="Number of CPUs used per client.", + ) + + # Client side + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." + ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval-client", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Scale the learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index 5174c236c668..72dde5706e8d 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -67,13 +67,13 @@ To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id=0 +python3 client.py --partition-id=0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id=1 +python3 client.py --partition-id=1 ``` You will see that XGBoost is starting a federated training. @@ -85,4 +85,4 @@ poetry run ./run.sh ``` Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) -and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +and [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index b5eab59ba14d..6ac23ae15148 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -24,13 +24,13 @@ warnings.filterwarnings("ignore", category=UserWarning) -# Define arguments parser for the client/node ID. +# Define arguments parser for the client/partition ID. parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() @@ -61,9 +61,9 @@ def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.
partitioner = IidPartitioner(num_partitions=30) fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") # Train/test splitting @@ -173,4 +173,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client()) diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml index 7b3cbd9659a2..c16542ea7ffe 100644 --- a/examples/xgboost-quickstart/pyproject.toml +++ b/examples/xgboost-quickstart/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "xgboost-quickstart" version = "0.1.0" description = "Federated XGBoost with Flower (quickstart)" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.6.0,<2.0" +flwr = ">=1.7.0,<2.0" flwr-datasets = ">=0.0.1,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt index 4ccd5587bfc3..c6949e0651c5 100644 --- a/examples/xgboost-quickstart/requirements.txt +++ b/examples/xgboost-quickstart/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.6.0, <2.0 +flwr>=1.7.0, <2.0 flwr-datasets>=0.0.1, <1.0.0 xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh index 6287145bfb5f..b35af58222ab 100755 --- a/examples/xgboost-quickstart/run.sh +++ b/examples/xgboost-quickstart/run.sh @@ -8,7 +8,7 @@ sleep 5 # Sleep for 5s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python3 client.py --node-id=$i & + python3 client.py --partition-id=$i & done # Enable CTRL+C to stop all background processes diff --git a/pyproject.toml b/pyproject.toml index 1ccdc72666f6..e0514254ecac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,14 +4,14 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "flwr" -version = "1.7.0" +version = "1.8.0" description = "Flower: A Friendly Federated Learning Framework" license = "Apache-2.0" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" keywords = [ "flower", "fl", @@ -52,10 +52,13 @@ exclude = [ ] [tool.poetry.scripts] +flwr = "flwr.cli.app:app" flower-driver-api = "flwr.server:run_driver_api" flower-fleet-api = "flwr.server:run_fleet_api" -flower-server = "flwr.server:run_server" -flower-client = "flwr.client:run_client" +flower-superlink = "flwr.server:run_superlink" +flower-client-app = "flwr.client:run_client_app" +flower-server-app = "flwr.server:run_server_app" +flower-simulation = "flwr.simulation:run_simulation_from_cli" [tool.poetry.dependencies] python = "^3.8" @@ -63,10 +66,12 @@ python = "^3.8" numpy = "^1.21.0" grpcio = "^1.60.0" protobuf = "^4.25.2" -cryptography = "^41.0.2" +cryptography = "^42.0.4" pycryptodome = "^3.18.0" iterators = "^0.0.2" -# Optional dependencies (VCE) +typer = 
{ version = "^0.9.0", extras=["all"] } +tomli = "^2.0.1" +# Optional dependencies (Simulation Engine) ray = { version = "==2.6.3", optional = true } pydantic = { version = "<2.0.0", optional = true } # Optional dependencies (REST transport layer) @@ -81,21 +86,21 @@ rest = ["requests", "starlette", "uvicorn"] [tool.poetry.group.dev.dependencies] types-dataclasses = "==0.6.6" types-protobuf = "==3.19.18" -types-requests = "==2.31.0.10" -types-setuptools = "==69.0.0.20240115" -clang-format = "==17.0.4" +types-requests = "==2.31.0.20240125" +types-setuptools = "==69.0.0.20240125" +clang-format = "==17.0.6" isort = "==5.13.2" -black = { version = "==23.10.1", extras = ["jupyter"] } +black = { version = "==24.2.0", extras = ["jupyter"] } docformatter = "==1.7.5" mypy = "==1.8.0" pylint = "==3.0.3" flake8 = "==5.0.4" pytest = "==7.4.4" pytest-cov = "==4.1.0" -pytest-watch = "==4.2.0" +pytest-watcher = "==0.4.1" grpcio-tools = "==1.60.0" mypy-protobuf = "==3.2.0" -jupyterlab = "==4.0.9" +jupyterlab = "==4.0.12" rope = "==1.11.0" semver = "==3.0.2" sphinx = "==6.2.1" @@ -146,6 +151,16 @@ testpaths = [ "src/py/flwr", "src/py/flwr_tool", ] +filterwarnings = "ignore::DeprecationWarning" + +[tool.pytest-watcher] +now = false +clear = true +delay = 0.2 +runner = "pytest" +runner_args = ["-s", "-vvvvv"] +patterns = ["*.py"] +ignore_patterns = [] [tool.mypy] plugins = [ diff --git a/src/docker/server/Dockerfile b/src/docker/server/Dockerfile index c42246b16104..faa9cf2e56fe 100644 --- a/src/docker/server/Dockerfile +++ b/src/docker/server/Dockerfile @@ -7,8 +7,8 @@ FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG as server WORKDIR /app ARG FLWR_VERSION RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION} -ENTRYPOINT ["python", "-c", "from flwr.server import run_server; run_server()"] +ENTRYPOINT ["python", "-c", "from flwr.server import run_superlink; run_superlink()"] # Test if Flower can be successfully installed and imported FROM server as test -RUN python -c "from flwr.server import run_server" +RUN python -c "from flwr.server import run_superlink" diff --git a/src/proto/flwr/proto/error.proto b/src/proto/flwr/proto/error.proto new file mode 100644 index 000000000000..a35af7f8af67 --- /dev/null +++ b/src/proto/flwr/proto/error.proto @@ -0,0 +1,23 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message Error { + sint64 code = 1; + string reason = 2; +} diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 2cde16143d8d..423df76f1335 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -20,6 +20,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/recordset.proto"; import "flwr/proto/transport.proto"; +import "flwr/proto/error.proto"; message Task { Node producer = 1; @@ -30,10 +31,7 @@ message Task { repeated string ancestry = 6; string task_type = 7; RecordSet recordset = 8; - - ServerMessage legacy_server_message = 101 [ deprecated = true ]; - ClientMessage legacy_client_message = 102 [ deprecated = true ]; - SecureAggregation sa = 103 [ deprecated = true ]; + Error error = 9; } message TaskIns { @@ -49,23 +47,3 @@ message TaskRes { sint64 run_id = 3; Task task = 4; } - -message Value { - oneof value { - // Single element - double double = 1; - sint64 sint64 = 2; - bool bool = 3; - string string = 4; - bytes bytes = 5; - - // List types - DoubleList double_list = 21; - Sint64List sint64_list = 22; - BoolList bool_list = 23; - StringList string_list = 24; - BytesList bytes_list = 25; - } -} - -message SecureAggregation { map named_values = 1; } diff --git a/src/py/flwr/__init__.py b/src/py/flwr/__init__.py index e05799280339..ccaf07c6012f 100644 --- a/src/py/flwr/__init__.py +++ b/src/py/flwr/__init__.py @@ -17,13 +17,11 @@ from flwr.common.version import package_version as _package_version -from . import client, common, driver, flower, server, simulation +from . import client, common, server, simulation __all__ = [ "client", "common", - "driver", - "flower", "server", "simulation", ] diff --git a/src/py/flwr/cli/__init__.py b/src/py/flwr/cli/__init__.py new file mode 100644 index 000000000000..d4d3b8ac4d48 --- /dev/null +++ b/src/py/flwr/cli/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface.""" diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py new file mode 100644 index 000000000000..1c6ee6c97841 --- /dev/null +++ b/src/py/flwr/cli/app.py @@ -0,0 +1,37 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
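The new `Error` message above gets compiled to a generated Python class like the repo's other protos. A hypothetical usage sketch; the `flwr.proto.error_pb2` module path is an assumption based on where the generated code for the existing messages lives:

```python
# Assumed generated module for error.proto (not shown in this diff)
from flwr.proto.error_pb2 import Error

# proto3 messages accept keyword arguments for their fields
error = Error(code=1, reason="client raised an exception")
assert error.code == 1
assert error.reason == "client raised an exception"
```

This is the message that replaces the deprecated `legacy_*` fields on `Task`, carried in the new `error` field alongside `recordset`.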
+# ============================================================================== +"""Flower command line interface.""" + +import typer + +from .example import example +from .new import new +from .run import run + +app = typer.Typer( + help=typer.style( + "flwr is the Flower command line interface.", + fg=typer.colors.BRIGHT_YELLOW, + bold=True, + ), + no_args_is_help=True, +) + +app.command()(new) +app.command()(example) +app.command()(run) + +if __name__ == "__main__": + app() diff --git a/src/py/flwr/cli/example.py b/src/py/flwr/cli/example.py new file mode 100644 index 000000000000..625ca8729640 --- /dev/null +++ b/src/py/flwr/cli/example.py @@ -0,0 +1,64 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `example` command.""" + +import json +import os +import subprocess +import tempfile +import urllib.request + +from .utils import prompt_options + + +def example() -> None: + """Clone a Flower example. + + All examples in the Flower repository can be cloned through this command. + """ + # Load list of examples directly from GitHub + url = "https://api.github.com/repos/adap/flower/git/trees/main" + with urllib.request.urlopen(url) as res: + data = json.load(res) + examples_directory_url = [ + item["url"] for item in data["tree"] if item["path"] == "examples" + ][0] + + with urllib.request.urlopen(examples_directory_url) as res: + data = json.load(res) + example_names = [ + item["path"] for item in data["tree"] if item["path"] not in [".gitignore"] + ] + + example_name = prompt_options( + "Please select example by typing in the number", + example_names, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + subprocess.check_output( + [ + "git", + "clone", + "--depth=1", + "https://github.com/adap/flower.git", + tmpdirname, + ] + ) + examples_dir = os.path.join(tmpdirname, "examples", example_name) + subprocess.check_output(["mv", examples_dir, "."]) + + print() + print(f"Example ready to use in {os.path.join(os.getcwd(), example_name)}") diff --git a/src/py/flwr/cli/flower_toml.py b/src/py/flwr/cli/flower_toml.py new file mode 100644 index 000000000000..75d4b9f7e2cd --- /dev/null +++ b/src/py/flwr/cli/flower_toml.py @@ -0,0 +1,107 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
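A quick way to exercise the Typer app defined in `app.py` above without installing the `flwr` console script is Typer's bundled test runner; this sketch only assumes the package is importable:

```python
from typer.testing import CliRunner

from flwr.cli.app import app

runner = CliRunner()
result = runner.invoke(app, ["--help"])
print(result.output)  # lists the `new`, `example` and `run` subcommands
```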
+# ============================================================================== +"""Utility to validate the `flower.toml` file.""" + +import os +from typing import Any, Dict, List, Optional, Tuple + +import tomli + +from flwr.common.object_ref import validate + + +def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Load flower.toml and return as dict.""" + if path is None: + cur_dir = os.getcwd() + toml_path = os.path.join(cur_dir, "flower.toml") + else: + toml_path = path + + if not os.path.isfile(toml_path): + return None + + with open(toml_path, encoding="utf-8") as toml_file: + data = tomli.loads(toml_file.read()) + return data + + +def validate_flower_toml_fields( + config: Dict[str, Any] +) -> Tuple[bool, List[str], List[str]]: + """Validate flower.toml fields.""" + errors = [] + warnings = [] + + if "project" not in config: + errors.append("Missing [project] section") + else: + if "name" not in config["project"]: + errors.append('Property "name" missing in [project]') + if "version" not in config["project"]: + errors.append('Property "version" missing in [project]') + if "description" not in config["project"]: + warnings.append('Recommended property "description" missing in [project]') + if "license" not in config["project"]: + warnings.append('Recommended property "license" missing in [project]') + if "authors" not in config["project"]: + warnings.append('Recommended property "authors" missing in [project]') + + if "flower" not in config: + errors.append("Missing [flower] section") + elif "components" not in config["flower"]: + errors.append("Missing [flower.components] section") + else: + if "serverapp" not in config["flower"]["components"]: + errors.append('Property "serverapp" missing in [flower.components]') + if "clientapp" not in config["flower"]["components"]: + errors.append('Property "clientapp" missing in [flower.components]') + + return len(errors) == 0, errors, warnings + + +def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: + """Validate flower.toml.""" + is_valid, errors, warnings = validate_flower_toml_fields(config) + + if not is_valid: + return False, errors, warnings + + # Validate serverapp + is_valid, reason = validate(config["flower"]["components"]["serverapp"]) + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + # Validate clientapp + is_valid, reason = validate(config["flower"]["components"]["clientapp"]) + + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + return True, [], [] + + +def apply_defaults( + config: Dict[str, Any], + defaults: Dict[str, Any], +) -> Dict[str, Any]: + """Apply defaults to config.""" + for key in defaults: + if key in config: + if isinstance(config[key], dict) and isinstance(defaults[key], dict): + apply_defaults(config[key], defaults[key]) + else: + config[key] = defaults[key] + return config diff --git a/src/py/flwr/cli/flower_toml_test.py b/src/py/flwr/cli/flower_toml_test.py new file mode 100644 index 000000000000..67ccab97e59d --- /dev/null +++ b/src/py/flwr/cli/flower_toml_test.py @@ -0,0 +1,288 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
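Based on the checks in `validate_flower_toml_fields` above, this is a minimal config dict that validates cleanly; `description`, `license` and `authors` are only recommended, so omitting them produces warnings rather than errors:

```python
from flwr.cli.flower_toml import validate_flower_toml_fields

config = {
    "project": {"name": "fedgpt", "version": "1.0.0"},
    "flower": {
        "components": {
            "serverapp": "fedgpt.server:app",
            "clientapp": "fedgpt.client:app",
        }
    },
}

is_valid, errors, warnings = validate_flower_toml_fields(config)
assert is_valid and not errors
assert len(warnings) == 3  # the three recommended [project] properties
```

Note that `validate_flower_toml` goes one step further and also resolves the two component references via `flwr.common.object_ref.validate`, so they must point at importable objects.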
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the `flower.toml` utilities.""" + +import os +import textwrap +from typing import Any, Dict + +from .flower_toml import ( + load_flower_toml, + validate_flower_toml, + validate_flower_toml_fields, +) + + +def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None: + """Test if load_flower_toml loads the config from the current directory.""" + # Prepare + flower_toml_content = """ + [project] + name = "fedgpt" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.engine] + name = "simulation" # optional + + [flower.engine.simulation.supernode] + count = 10 # optional + """ + expected_config = { + "project": { + "name": "fedgpt", + }, + "flower": { + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "engine": { + "name": "simulation", + "simulation": {"supernode": {"count": 10}}, + }, + }, + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open("flower.toml", "w", encoding="utf-8") as f: + f.write(textwrap.dedent(flower_toml_content)) + + # Execute + config = load_flower_toml() + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_load_flower_toml_from_path(tmp_path: str) -> None: + """Test if load_flower_toml loads the config from a given path.""" + # Prepare + flower_toml_content = """ + [project] + name = "fedgpt" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.engine] + name = "simulation" # optional + + [flower.engine.simulation.supernode] + count = 10 # optional + """ + expected_config = { + "project": { + "name": "fedgpt", + }, + "flower": { + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "engine": { + "name": "simulation", + "simulation": {"supernode": {"count": 10}}, + }, + }, + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open("flower.toml", "w", encoding="utf-8") as f: + f.write(textwrap.dedent(flower_toml_content)) + + # Execute + config = load_flower_toml(path=os.path.join(tmp_path, "flower.toml")) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_validate_flower_toml_fields_empty() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config: Dict[str, Any] = {} + + # Execute + is_valid, errors, warnings = validate_flower_toml_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_flower() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + } + } + + # Execute + is_valid, errors, warnings = validate_flower_toml_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert 
len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_flower_components() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {}, + } + + # Execute + is_valid, errors, warnings = validate_flower_toml_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_server_and_client_app() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {"components": {}}, + } + + # Execute + is_valid, errors, warnings = validate_flower_toml_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields() -> None: + """Test that validate_flower_toml_fields succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {"components": {"serverapp": "", "clientapp": ""}}, + } + + # Execute + is_valid, errors, warnings = validate_flower_toml_fields(config) + + # Assert + assert is_valid + assert len(errors) == 0 + assert len(warnings) == 0 + + +def test_validate_flower_toml() -> None: + """Test that validate_flower_toml succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": { + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:run", + } + }, + } + + # Execute + is_valid, errors, warnings = validate_flower_toml(config) + + # Assert + assert is_valid + assert not errors + assert not warnings + + +def test_validate_flower_toml_fail() -> None: + """Test that validate_flower_toml fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": { + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:runa", + } + }, + } + + # Execute + is_valid, errors, warnings = validate_flower_toml(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 diff --git a/src/py/flwr/client/run_state.py b/src/py/flwr/cli/new/__init__.py similarity index 72% rename from src/py/flwr/client/run_state.py rename to src/py/flwr/cli/new/__init__.py index c2755eb995eb..a973f47021c3 100644 --- a/src/py/flwr/client/run_state.py +++ b/src/py/flwr/cli/new/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Run state.""" +"""Flower command line interface `new` command.""" -from dataclasses import dataclass -from typing import Dict +from .new import new as new - -@dataclass -class RunState: - """State of a run executed by a client node.""" - - state: Dict[str, str] +__all__ = [ + "new", +] diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py new file mode 100644 index 000000000000..8d644391ca5b --- /dev/null +++ b/src/py/flwr/cli/new/new.py @@ -0,0 +1,134 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `new` command.""" + +import os +from enum import Enum +from string import Template +from typing import Dict, Optional + +import typer +from typing_extensions import Annotated + +from ..utils import prompt_options, prompt_text + + +class MlFramework(str, Enum): + """Available frameworks.""" + + NUMPY = "NumPy" + PYTORCH = "PyTorch" + TENSORFLOW = "TensorFlow" + + +class TemplateNotFound(Exception): + """Raised when template does not exist.""" + + +def load_template(name: str) -> str: + """Load template from template directory and return as text.""" + tpl_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "templates")) + tpl_file_path = os.path.join(tpl_dir, name) + + if not os.path.isfile(tpl_file_path): + raise TemplateNotFound(f"Template '{name}' not found") + + with open(tpl_file_path, encoding="utf-8") as tpl_file: + return tpl_file.read() + + +def render_template(template: str, data: Dict[str, str]) -> str: + """Render template.""" + tpl_file = load_template(template) + tpl = Template(tpl_file) + result = tpl.substitute(data) + return result + + +def create_file(file_path: str, content: str) -> None: + """Create file including all necessary directories and write content into file.""" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + +def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> None: + """Render template and write to file.""" + content = render_template(template, context) + create_file(file_path, content) + + +def new( + project_name: Annotated[ + Optional[str], + typer.Argument(metavar="project_name", help="The name of the project"), + ] = None, + framework: Annotated[ + Optional[MlFramework], + typer.Option(case_sensitive=False, help="The ML framework to use"), + ] = None, +) -> None: + """Create new Flower project.""" + if project_name is None: + project_name = prompt_text("Please provide project name") + + print(f"Creating Flower project {project_name}...") + + if framework is not None: + framework_str = str(framework.value) + else: + framework_value = prompt_options( + "Please select ML framework by typing in the number", + [mlf.value for mlf in MlFramework], + ) + selected_value = [ + name + 
for name, value in vars(MlFramework).items() + if value == framework_value + ] + framework_str = selected_value[0] + + framework_str = framework_str.lower() + + # Set project directory path + cwd = os.getcwd() + pnl = project_name.lower() + project_dir = os.path.join(cwd, pnl) + + # List of files to render + files = { + "README.md": {"template": "app/README.md.tpl"}, + "requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"}, + "flower.toml": {"template": "app/flower.toml.tpl"}, + "pyproject.toml": {"template": "app/pyproject.toml.tpl"}, + f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"}, + f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"}, + f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"}, + } + + # If the framework is MlFramework.PYTORCH, additionally generate the task.py file + if framework_str == MlFramework.PYTORCH.value.lower(): + files[f"{pnl}/task.py"] = {"template": f"app/code/task.{framework_str}.py.tpl"} + + context = {"project_name": project_name} + + for file_path, value in files.items(): + render_and_create( + file_path=os.path.join(project_dir, file_path), + template=value["template"], + context=context, + ) + + print("Project creation successful.") diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py new file mode 100644 index 000000000000..cedcb09b7755 --- /dev/null +++ b/src/py/flwr/cli/new/new_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
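The rendering machinery in `new.py` above is plain `string.Template` substitution, nothing more. For instance, with the `__init__.py.tpl` template added later in this diff (whose entire content is `"""$project_name."""`):

```python
from string import Template

# Content of app/code/__init__.py.tpl
tpl = Template('"""$project_name."""')
print(tpl.substitute({"project_name": "fedgpt"}))  # -> """fedgpt."""
```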
+# ============================================================================== +"""Test for Flower command line interface `new` command.""" + +import os + +from .new import MlFramework, create_file, load_template, new, render_template + + +def test_load_template() -> None: + """Test if load_template returns a string.""" + # Prepare + filename = "app/README.md.tpl" + + # Execute + text = load_template(filename) + + # Assert + assert isinstance(text, str) + + +def test_render_template() -> None: + """Test if a string is correctly substituted.""" + # Prepare + filename = "app/README.md.tpl" + data = {"project_name": "FedGPT"} + + # Execute + result = render_template(filename, data) + + # Assert + assert "# FedGPT" in result + + +def test_create_file(tmp_path: str) -> None: + """Test if file with content is created.""" + # Prepare + file_path = os.path.join(tmp_path, "test.txt") + content = "Foobar" + + # Execute + create_file(file_path, content) + + # Assert + with open(file_path, encoding="utf-8") as f: + text = f.read() + + assert text == "Foobar" + + +def test_new(tmp_path: str) -> None: + """Test if project is created for framework.""" + # Prepare + project_name = "FedGPT" + framework = MlFramework.PYTORCH + expected_files_top_level = { + "requirements.txt", + "fedgpt", + "README.md", + "flower.toml", + "pyproject.toml", + } + expected_files_module = { + "__init__.py", + "server.py", + "client.py", + "task.py", + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + + # Execute + new(project_name=project_name, framework=framework) + + # Assert + file_list = os.listdir(os.path.join(tmp_path, project_name.lower())) + assert set(file_list) == expected_files_top_level + + file_list = os.listdir( + os.path.join(tmp_path, project_name.lower(), project_name.lower()) + ) + assert set(file_list) == expected_files_module + finally: + os.chdir(origin) diff --git a/src/py/flwr/cli/new/templates/__init__.py b/src/py/flwr/cli/new/templates/__init__.py new file mode 100644 index 000000000000..7a951c2da1a2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
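As `test_new` above demonstrates, the `new` command is an ordinary function and can be driven directly from Python; a sketch (run it from an empty working directory):

```python
from flwr.cli.new import new
from flwr.cli.new.new import MlFramework

new(project_name="FedGPT", framework=MlFramework.PYTORCH)
# Creates ./fedgpt/ with README.md, flower.toml, pyproject.toml,
# requirements.txt and a fedgpt/ package containing __init__.py,
# client.py, server.py and task.py.
```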
+# ============================================================================== +"""Flower CLI `new` command templates.""" diff --git a/src/py/flwr/cli/new/templates/app/README.md.tpl b/src/py/flwr/cli/new/templates/app/README.md.tpl new file mode 100644 index 000000000000..6edb99a7f5ed --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/README.md.tpl @@ -0,0 +1,43 @@ +# $project_name + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Run (Simulation Engine) + +In the `$project_name` directory, use `flwr run` to run a local simulation: + +```bash +flwr run +``` + +## Run (Deployment Engine) + +### Start the SuperLink + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +### Start the ServerApp + +```bash +flower-server-app server:app --insecure +``` diff --git a/src/py/flwr/cli/new/templates/app/__init__.py b/src/py/flwr/cli/new/templates/app/__init__.py new file mode 100644 index 000000000000..617628fc9138 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower CLI `new` command app templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py b/src/py/flwr/cli/new/templates/app/code/__init__.py new file mode 100644 index 000000000000..7f1a0e9f4fa2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower CLI `new` command app / code templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl new file mode 100644 index 000000000000..57998c81efb8 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl @@ -0,0 +1 @@ +"""$project_name.""" diff --git a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl new file mode 100644 index 000000000000..232c305fc2a9 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl @@ -0,0 +1,23 @@ +"""$project_name: A Flower / NumPy app.""" + +from flwr.client import NumPyClient, ClientApp +import numpy as np + + +class FlowerClient(NumPyClient): + def get_parameters(self, config): + return [np.ones((1, 1))] + + def fit(self, parameters, config): + return ([np.ones((1, 1))], 1, {}) + + def evaluate(self, parameters, config): + return float(0.0), 1, {"accuracy": float(1.0)} + + +def client_fn(cid: str): + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp(client_fn=client_fn) diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl new file mode 100644 index 000000000000..bdb5b8fcadf9 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -0,0 +1,44 @@ +"""$project_name: A Flower / PyTorch app.""" + +from flwr.client import NumPyClient, ClientApp + +from $project_name.task import ( + Net, + DEVICE, + load_data, + get_weights, + set_weights, + train, + test, +) + + +# Define Flower Client and client_fn +class FlowerClient(NumPyClient): + def __init__(self, net, trainloader, valloader) -> None: + self.net = net + self.trainloader = trainloader + self.valloader = valloader + + def fit(self, parameters, config): + set_weights(self.net, parameters) + results = train(self.net, self.trainloader, self.valloader, 1, DEVICE) + return get_weights(self.net), len(self.trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.valloader) + return loss, len(self.valloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + # Load model and data + net = Net().to(DEVICE) + trainloader, valloader = load_data() + + # Return Client instance + return FlowerClient(net, trainloader, valloader).to_client() + + +# Flower ClientApp +app = ClientApp(client_fn) diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl new file mode 100644 index 000000000000..03f95ae35cfd --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl @@ -0,0 +1,12 @@ +"""$project_name: A Flower / NumPy app.""" + +import flwr as fl + +# Configure the strategy +strategy = fl.server.strategy.FedAvg() + +# Flower ServerApp +app = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=1), + strategy=strategy, +) diff --git 
a/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl new file mode 100644 index 000000000000..cb04c052b429 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl @@ -0,0 +1,28 @@ +"""$project_name: A Flower / PyTorch app.""" + +from flwr.common import ndarrays_to_parameters +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg + +from $project_name.task import Net, get_weights + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0, + fraction_evaluate=1.0, + min_available_clients=2, + initial_parameters=parameters, +) + + +# Create ServerApp +app = ServerApp( + config=ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl new file mode 100644 index 000000000000..1d727599a1e4 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -0,0 +1,94 @@ +"""$project_name: A Flower / PyTorch app.""" + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +def train(net, trainloader, valloader, epochs, device): + """Train the model on the training set.""" + print("Starting training...") + net.to(device) # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net.train() + for _ in range(epochs): + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + loss.backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + 
"train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + """Validate the model on the test set.""" + net.to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in testloader: + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/src/py/flwr/cli/new/templates/app/flower.toml.tpl b/src/py/flwr/cli/new/templates/app/flower.toml.tpl new file mode 100644 index 000000000000..07a6ffaf9e49 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/flower.toml.tpl @@ -0,0 +1,13 @@ +[project] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[flower.components] +serverapp = "$project_name.server:app" +clientapp = "$project_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.toml.tpl new file mode 100644 index 000000000000..ca3f625e2437 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/pyproject.toml.tpl @@ -0,0 +1,21 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.9" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } +flwr-datasets = { version = "^0.0.2", extras = ["vision"] } +torch = "2.2.1" +torchvision = "0.17.1" diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl new file mode 100644 index 000000000000..4b460798e96f --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl @@ -0,0 +1,2 @@ +flwr>=1.8, <2.0 +numpy>=1.21.0 diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl new file mode 100644 index 000000000000..016a84043cbe --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl @@ -0,0 +1,4 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 +flwr-datasets[vision]==0.0.2 +torch==2.2.1 +torchvision==0.17.1 diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl new file mode 100644 index 000000000000..b6fb49a4bbcb --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl @@ -0,0 +1,4 @@ +flwr>=1.8, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +tensorflow-macos>=2.9.1, !=2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" +tensorflow-cpu>=2.9.1, !=2.11.1 ; platform_machine == "x86_64" diff --git a/src/py/flwr/cli/run/__init__.py b/src/py/flwr/cli/run/__init__.py new file mode 100644 index 000000000000..43523c215d3e --- 
/dev/null +++ b/src/py/flwr/cli/run/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `run` command.""" + +from .run import run as run + +__all__ = [ + "run", +] diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py new file mode 100644 index 000000000000..d0838d18d7e4 --- /dev/null +++ b/src/py/flwr/cli/run/run.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `run` command.""" + +import sys + +import typer + +from flwr.cli.flower_toml import apply_defaults, load_flower_toml, validate_flower_toml +from flwr.simulation.run_simulation import _run_simulation + + +def run() -> None: + """Run Flower project.""" + print( + typer.style("Loading project configuration... ", fg=typer.colors.BLUE), + end="", + ) + config = load_flower_toml() + if not config: + print( + typer.style( + "Project configuration could not be loaded. " + "flower.toml does not exist.", + fg=typer.colors.RED, + bold=True, + ) + ) + sys.exit() + print(typer.style("Success", fg=typer.colors.GREEN)) + + print( + typer.style("Validating project configuration... 
", fg=typer.colors.BLUE), + end="", + ) + is_valid, errors, warnings = validate_flower_toml(config) + if warnings: + print( + typer.style( + "Project configuration is missing the following " + "recommended properties:\n" + + "\n".join([f"- {line}" for line in warnings]), + fg=typer.colors.RED, + bold=True, + ) + ) + + if not is_valid: + print( + typer.style( + "Project configuration could not be loaded.\nflower.toml is invalid:\n" + + "\n".join([f"- {line}" for line in errors]), + fg=typer.colors.RED, + bold=True, + ) + ) + sys.exit() + print(typer.style("Success", fg=typer.colors.GREEN)) + + # Apply defaults + defaults = { + "flower": { + "engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}} + } + } + config = apply_defaults(config, defaults) + + server_app_ref = config["flower"]["components"]["serverapp"] + client_app_ref = config["flower"]["components"]["clientapp"] + engine = config["flower"]["engine"]["name"] + + if engine == "simulation": + num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"] + + print( + typer.style("Starting run... ", fg=typer.colors.BLUE), + ) + _run_simulation( + server_app_attr=server_app_ref, + client_app_attr=client_app_ref, + num_supernodes=num_supernodes, + ) + else: + print( + typer.style( + f"Engine '{engine}' is not yet supported in `flwr run`", + fg=typer.colors.RED, + bold=True, + ) + ) diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py new file mode 100644 index 000000000000..4e86f0c3b8c8 --- /dev/null +++ b/src/py/flwr/cli/utils.py @@ -0,0 +1,67 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower command line interface utils.""" + +from typing import List, cast + +import typer + + +def prompt_text(text: str) -> str: + """Ask user to enter text input.""" + while True: + result = typer.prompt( + typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True) + ) + if len(result) > 0: + break + print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True)) + + return cast(str, result) + + +def prompt_options(text: str, options: List[str]) -> str: + """Ask user to select one of the given options and return the selected item.""" + # Turn options into a list with index as in " [ 0] quickstart-pytorch" + options_formatted = [ + " [ " + + typer.style(index, fg=typer.colors.GREEN, bold=True) + + "]" + + f" {typer.style(name, fg=typer.colors.WHITE, bold=True)}" + for index, name in enumerate(options) + ] + + while True: + index = typer.prompt( + "\n" + + typer.style(f"💬 {text}", fg=typer.colors.MAGENTA, bold=True) + + "\n\n" + + "\n".join(options_formatted) + + "\n\n\n" + ) + try: + options[int(index)] # pylint: disable=expression-not-assigned + break + except IndexError: + print(typer.style("❌ Index out of range", fg=typer.colors.RED, bold=True)) + continue + except ValueError: + print( + typer.style("❌ Please choose a number", fg=typer.colors.RED, bold=True) + ) + continue + + result = options[int(index)] + return result diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index 13540a76cc25..a721fb584164 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -15,18 +15,20 @@ """Flower client.""" -from .app import run_client as run_client +from .app import run_client_app as run_client_app from .app import start_client as start_client from .app import start_numpy_client as start_numpy_client from .client import Client as Client +from .client_app import ClientApp as ClientApp from .numpy_client import NumPyClient as NumPyClient from .typing import ClientFn as ClientFn __all__ = [ "Client", + "ClientApp", "ClientFn", "NumPyClient", - "run_client", + "run_client_app", "start_client", "start_numpy_client", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ae5beeae07d6..c8287afc0fd0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,14 +18,16 @@ import argparse import sys import time -from logging import INFO, WARN +from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Callable, ContextManager, Optional, Tuple, Union +from typing import Callable, ContextManager, Optional, Tuple, Type, Union + +from grpc import RpcError from flwr.client.client import Client -from flwr.client.flower import Flower -from flwr.client.typing import Bwd, ClientFn, Fwd -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.typing import ClientFn +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -34,10 +36,11 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.common.exit_handlers import register_exit_handlers +from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.object_ref import 
load_app, validate +from flwr.common.retry_invoker import RetryInvoker, exponential -from .flower import load_flower_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message @@ -45,13 +48,13 @@ from .numpy_client import NumPyClient -def run_client() -> None: - """Run Flower client.""" - event(EventType.RUN_CLIENT_ENTER) +def run_client_app() -> None: + """Run Flower client app.""" + event(EventType.RUN_CLIENT_APP_ENTER) log(INFO, "Long-running Flower client starting") - args = _parse_args_client().parse_args() + args = _parse_args_run_client_app().parse_args() # Obtain certificates if args.insecure: @@ -62,7 +65,12 @@ def run_client() -> None: "the '--root-certificates' option when running in insecure mode, " "or omit '--insecure' to use HTTPS." ) - log(WARN, "Option `--insecure` was set. Starting insecure HTTP client.") + log( + WARN, + "Option `--insecure` was set. " + "Starting insecure HTTP client connected to %s.", + args.server, + ) root_certificates = None else: # Load the certificates if provided, or load the system certificates @@ -71,39 +79,60 @@ def run_client() -> None: root_certificates = None else: root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) + + log( + DEBUG, + "Flower will load ClientApp `%s`", + getattr(args, "client-app"), + ) + + client_app_dir = args.dir + if client_app_dir is not None: + sys.path.insert(0, client_app_dir) + + app_ref: str = getattr(args, "client-app") + valid, error_msg = validate(app_ref) + if not valid and error_msg: + raise LoadClientAppError(error_msg) from None - print(args.root_certificates) - print(args.server) - print(args.dir) - print(args.callable) + def _load() -> ClientApp: + client_app = load_app(app_ref, LoadClientAppError) - callable_dir = args.dir - if callable_dir is not None: - sys.path.insert(0, callable_dir) + if not isinstance(client_app, ClientApp): + raise LoadClientAppError( + f"Attribute {app_ref} is not of type {ClientApp}", + ) from None - def _load() -> Flower: - flower: Flower = load_flower_callable(args.callable) - return flower + return client_app _start_client_internal( server_address=args.server, - load_flower_callable_fn=_load, - transport="grpc-rere", # Only + load_client_app_fn=_load, + transport="rest" if args.rest else "grpc-rere", root_certificates=root_certificates, insecure=args.insecure, + max_retries=args.max_retries, + max_wait_time=args.max_wait_time, ) - event(EventType.RUN_CLIENT_LEAVE) + register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) -def _parse_args_client() -> argparse.ArgumentParser: - """Parse command line arguments.""" +def _parse_args_run_client_app() -> argparse.ArgumentParser: + """Parse flower-client-app command line arguments.""" parser = argparse.ArgumentParser( - description="Start a long-running Flower client", + description="Start a Flower client app", ) parser.add_argument( - "callable", - help="For example: `client:flower` or `project.package.module:wrapper.flower`", + "client-app", + help="For example: `client:app` or `project.package.module:wrapper.app`", ) parser.add_argument( "--insecure", @@ -111,6 +140,11 @@ def _parse_args_client() -> argparse.ArgumentParser: help="Run the client without HTTPS. By default, the client runs with " "HTTPS enabled. 
Use this flag only if you understand the risks.", ) + parser.add_argument( + "--rest", + action="store_true", + help="Use REST as a transport layer for the client.", + ) parser.add_argument( "--root-certificates", metavar="ROOT_CERT", @@ -123,11 +157,27 @@ def _parse_args_client() -> argparse.ArgumentParser: default="0.0.0.0:9092", help="Server address", ) + parser.add_argument( + "--max-retries", + type=int, + default=None, + help="The maximum number of times the client will try to connect to the" + "server before giving up in case of a connection error. By default," + "it is set to None, meaning there is no limit to the number of tries.", + ) + parser.add_argument( + "--max-wait-time", + type=float, + default=None, + help="The maximum duration before the client stops trying to" + "connect to the server in case of connection error. By default, it" + "is set to None, meaning there is no limit to the total time.", + ) parser.add_argument( "--dir", default="", help="Add specified directory to the PYTHONPATH and load Flower " - "callable from there." + "app from there." " Default: current working directory.", ) @@ -162,6 +212,8 @@ def start_client( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -195,6 +247,14 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. + If set to None, there is no limit to the total time. Examples -------- @@ -229,13 +289,15 @@ class `flwr.client.Client` (default: None) event(EventType.START_CLIENT_ENTER) _start_client_internal( server_address=server_address, - load_flower_callable_fn=None, + load_client_app_fn=None, client_fn=client_fn, client=client, grpc_max_message_length=grpc_max_message_length, root_certificates=root_certificates, insecure=insecure, transport=transport, + max_retries=max_retries, + max_wait_time=max_wait_time, ) event(EventType.START_CLIENT_LEAVE) @@ -247,13 +309,15 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, - load_flower_callable_fn: Optional[Callable[[], Flower]] = None, + load_client_app_fn: Optional[Callable[[], ClientApp]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -263,8 +327,8 @@ def _start_client_internal( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - load_flower_callable_fn : Optional[Callable[[], Flower]] (default: None) - A function that can be used to load a `Flower` callable instance. 
+ load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) + A function that can be used to load a `ClientApp` instance. client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -281,7 +345,7 @@ class `flwr.client.Client` (default: None) The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. - insecure : bool (default: True) + insecure : Optional[bool] (default: None) Starts an insecure gRPC connection when True. Enables HTTPS connection when False, using system certificates if `root_certificates` is None. transport : Optional[str] (default: None) @@ -289,11 +353,19 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. + If set to None, there is no limit to the total time. """ if insecure is None: insecure = root_certificates is None - if load_flower_callable_fn is None: + if load_client_app_fn is None: _check_actionable_client(client, client_fn) if client_fn is None: @@ -309,18 +381,56 @@ def single_client_factory( client_fn = single_client_factory - def _load_app() -> Flower: - return Flower(client_fn=client_fn) + def _load_client_app() -> ClientApp: + return ClientApp(client_fn=client_fn) - load_flower_callable_fn = _load_app + load_client_app_fn = _load_client_app else: - warn_experimental_feature("`load_flower_callable_fn`") + warn_experimental_feature("`load_client_app_fn`") - # At this point, only `load_flower_callable_fn` should be used + # At this point, only `load_client_app_fn` should be used # Both `client` and `client_fn` must not be used directly # Initialize connection context manager - connection, address = _init_connection(transport, server_address) + connection, address, connection_error_type = _init_connection( + transport, server_address + ) + + retry_invoker = RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=connection_error_type, + max_tries=max_retries, + max_time=max_wait_time, + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), + ) node_state = NodeState() @@ -329,6 +439,7 @@ def _load_app() -> Flower: with connection( address, insecure, + retry_invoker, grpc_max_message_length, root_certificates, ) as conn: @@ -340,38 +451,63 @@ def _load_app() -> Flower: while True: # Receive - task_ins = receive() - if task_ins is None: + message = receive() + if message is 
None: time.sleep(3) # Wait for 3s before asking again continue + log(INFO, "") + log( + INFO, + "[RUN %s, ROUND %s]", + message.metadata.run_id, + message.metadata.group_id, + ) + log( + INFO, + "Received: %s message %s", + message.metadata.message_type, + message.metadata.message_id, + ) + # Handle control message - task_res, sleep_duration = handle_control_message(task_ins=task_ins) - if task_res: - send(task_res) + out_message, sleep_duration = handle_control_message(message) + if out_message: + send(out_message) break - # Register state - node_state.register_runstate(run_id=task_ins.run_id) + # Register context for this run + node_state.register_context(run_id=message.metadata.run_id) - # Load app - app: Flower = load_flower_callable_fn() + # Retrieve context for this run + context = node_state.retrieve_context(run_id=message.metadata.run_id) + + # Load ClientApp instance + client_app: ClientApp = load_client_app_fn() # Handle task message - fwd_msg: Fwd = Fwd( - task_ins=task_ins, - state=node_state.retrieve_runstate(run_id=task_ins.run_id), - ) - bwd_msg: Bwd = app(fwd=fwd_msg) + out_message = client_app(message=message, context=context) # Update node state - node_state.update_runstate( - run_id=bwd_msg.task_res.run_id, - run_state=bwd_msg.state, + node_state.update_context( + run_id=message.metadata.run_id, + context=context, ) # Send - send(bwd_msg.task_res) + send(out_message) + log( + INFO, + "[RUN %s, ROUND %s]", + out_message.metadata.run_id, + out_message.metadata.group_id, + ) + log( + INFO, + "Sent: %s reply to message %s", + out_message.metadata.message_type, + message.metadata.message_id, + ) # Unregister node if delete_node is not None: @@ -400,6 +536,12 @@ def start_numpy_client( ) -> None: """Start a Flower NumPyClient which connects to a gRPC server. + Warning + ------- + This function is deprecated since 1.7.0. Use :code:`flwr.client.start_client` + instead and first convert your :code:`NumPyClient` to type + :code:`flwr.client.Client` by executing its :code:`to_client()` method. + Parameters ---------- server_address : str @@ -455,21 +597,22 @@ def start_numpy_client( >>> root_certificates=Path("/crts/root.pem").read_bytes(), >>> ) """ - # warnings.warn( - # "flwr.client.start_numpy_client() is deprecated and will " - # "be removed in a future version of Flower. Instead, pass " - # "your client to `flwr.client.start_client()` by calling " - # "first the `.to_client()` method as shown below: \n" - # "\tflwr.client.start_client(\n" - # "\t\tserver_address=':',\n" - # "\t\tclient=FlowerClient().to_client()\n" - # "\t)", - # DeprecationWarning, - # stacklevel=2, - # ) + mssg = ( + "flwr.client.start_numpy_client() is deprecated. \n\tInstead, use " + "`flwr.client.start_client()` by ensuring you first call " + "the `.to_client()` method as shown below: \n" + "\tflwr.client.start_client(\n" + "\t\tserver_address=':',\n" + "\t\tclient=FlowerClient().to_client()," + " # <-- where FlowerClient is of type flwr.client.NumPyClient object\n" + "\t)\n" + "\tUsing `start_numpy_client()` is deprecated." + ) + + warn_deprecated_feature(name=mssg) # Calling this function is deprecated. A warning is thrown. 
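# Aside: a minimal migration sketch for the deprecation above (not from this
# commit). It assumes a `NumPyClient` subclass named `FlowerClient` and a
# server reachable at `127.0.0.1:8080`; both names are placeholders:
#
#     import flwr as fl
#
#     class FlowerClient(fl.client.NumPyClient):
#         ...  # get_parameters/fit/evaluate as before
#
#     fl.client.start_client(
#         server_address="127.0.0.1:8080",
#         client=FlowerClient().to_client(),  # convert NumPyClient -> Client
#     )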
- # We first need to convert either the supplied client to `Client.` + # We first need to convert the supplied client to `Client.` wrp_client = client.to_client() @@ -483,21 +626,20 @@ def start_numpy_client( ) -def _init_connection( - transport: Optional[str], server_address: str -) -> Tuple[ +def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[ - [str, bool, int, Union[bytes, str, None]], + [str, bool, RetryInvoker, int, Union[bytes, str, None]], ContextManager[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] ], ], str, + Type[Exception], ]: # Parse IP address parsed_address = parse_address(server_address) @@ -513,6 +655,8 @@ def _init_connection( # Use either gRPC bidirectional streaming or REST request/response if transport == TRANSPORT_TYPE_REST: try: + from requests.exceptions import ConnectionError as RequestsConnectionError + from .rest_client.connection import http_request_response except ModuleNotFoundError: sys.exit(MISSING_EXTRA_REST) @@ -521,14 +665,14 @@ def _init_connection( "When using the REST API, please provide `https://` or " "`http://` before the server address (e.g. `http://127.0.0.1:8080`)" ) - connection = http_request_response + connection, error_type = http_request_response, RequestsConnectionError elif transport == TRANSPORT_TYPE_GRPC_RERE: - connection = grpc_request_response + connection, error_type = grpc_request_response, RpcError elif transport == TRANSPORT_TYPE_GRPC_BIDI: - connection = grpc_connection + connection, error_type = grpc_connection, RpcError else: raise ValueError( f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})" ) - return connection, address + return connection, address, error_type diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 54b53296fd2f..23a3755f3efe 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,9 +19,9 @@ from abc import ABC -from flwr.client.run_state import RunState from flwr.common import ( Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -38,7 +38,7 @@ class Client(ABC): """Abstract base class for Flower clients.""" - state: RunState + context: Context def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> RunState: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RunState) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Return client (itself).""" diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py new file mode 100644 index 000000000000..ad7a01326991 --- /dev/null +++ b/src/py/flwr/client/client_app.py @@ -0,0 +1,224 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower ClientApp."""
+
+
+from typing import Callable, List, Optional
+
+from flwr.client.message_handler.message_handler import (
+    handle_legacy_message_from_msgtype,
+)
+from flwr.client.mod.utils import make_ffn
+from flwr.client.typing import ClientFn, Mod
+from flwr.common import Context, Message, MessageType
+
+from .typing import ClientAppCallable
+
+
+class ClientApp:
+    """Flower ClientApp.
+
+    Examples
+    --------
+    Assuming a typical `Client` implementation named `FlowerClient`, you can wrap it in
+    a `ClientApp` as follows:
+
+    >>> class FlowerClient(NumPyClient):
+    >>>     # ...
+    >>>
+    >>> def client_fn(cid):
+    >>>     return FlowerClient().to_client()
+    >>>
+    >>> app = ClientApp(client_fn)
+
+    If the above code is in a Python module called `client`, it can be started as
+    follows:
+
+    >>> flower-client-app client:app --insecure
+
+    In this `client:app` example, `client` refers to the Python module `client.py` in
+    which the previous code lives and `app` refers to the global attribute `app` that
+    points to an object of type `ClientApp`.
+    """
+
+    def __init__(
+        self,
+        client_fn: Optional[ClientFn] = None,  # Only for backward compatibility
+        mods: Optional[List[Mod]] = None,
+    ) -> None:
+        self._mods: List[Mod] = mods if mods is not None else []
+
+        # Create wrapper function for `handle`
+        self._call: Optional[ClientAppCallable] = None
+        if client_fn is not None:
+
+            def ffn(
+                message: Message,
+                context: Context,
+            ) -> Message:  # pylint: disable=invalid-name
+                out_message = handle_legacy_message_from_msgtype(
+                    client_fn=client_fn, message=message, context=context
+                )
+                return out_message
+
+            # Wrap mods around the wrapped handle function
+            self._call = make_ffn(ffn, mods if mods is not None else [])
+
+        # Step functions
+        self._train: Optional[ClientAppCallable] = None
+        self._evaluate: Optional[ClientAppCallable] = None
+        self._query: Optional[ClientAppCallable] = None
+
+    def __call__(self, message: Message, context: Context) -> Message:
+        """Execute `ClientApp`."""
+        # Execute message using `client_fn`
+        if self._call:
+            return self._call(message, context)
+
+        # Execute message using a registered step function (`train`/`evaluate`/`query`)
+        if message.metadata.message_type == MessageType.TRAIN:
+            if self._train:
+                return self._train(message, context)
+            raise ValueError("No `train` function registered")
+        if message.metadata.message_type == MessageType.EVALUATE:
+            if self._evaluate:
+                return self._evaluate(message, context)
+            raise ValueError("No `evaluate` function registered")
+        if message.metadata.message_type == MessageType.QUERY:
+            if self._query:
+                return self._query(message, context)
+            raise ValueError("No `query` function registered")
+
+        # Message type did not match one of the known message types above
+        raise ValueError(f"Unknown message_type: {message.metadata.message_type}")
+
+    def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
+        """Return a decorator that registers the train fn with the client app.
+
+        Examples
+        --------
+        >>> app = ClientApp()
+        >>>
+        >>> @app.train()
+        >>> def train(message: Message, context: Context) -> Message:
+        >>>    print("ClientApp training running")
+        >>>    # Create and return an echo reply message
+        >>>    return message.create_reply(content=message.content(), ttl="")
+        """
+
+        def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
+            """Register the train fn with the ClientApp object."""
+            if self._call:
+                raise _registration_error(MessageType.TRAIN)
+
+            # Register provided function with the ClientApp object
+            # Wrap mods around the wrapped step function
+            self._train = make_ffn(train_fn, self._mods)
+
+            # Return provided function unmodified
+            return train_fn
+
+        return train_decorator
+
+    def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
+        """Return a decorator that registers the evaluate fn with the client app.
+
+        Examples
+        --------
+        >>> app = ClientApp()
+        >>>
+        >>> @app.evaluate()
+        >>> def evaluate(message: Message, context: Context) -> Message:
+        >>>    print("ClientApp evaluation running")
+        >>>    # Create and return an echo reply message
+        >>>    return message.create_reply(content=message.content(), ttl="")
+        """
+
+        def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
+            """Register the evaluate fn with the ClientApp object."""
+            if self._call:
+                raise _registration_error(MessageType.EVALUATE)
+
+            # Register provided function with the ClientApp object
+            # Wrap mods around the wrapped step function
+            self._evaluate = make_ffn(evaluate_fn, self._mods)
+
+            # Return provided function unmodified
+            return evaluate_fn
+
+        return evaluate_decorator
+
+    def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
+        """Return a decorator that registers the query fn with the client app.
+
+        Examples
+        --------
+        >>> app = ClientApp()
+        >>>
+        >>> @app.query()
+        >>> def query(message: Message, context: Context) -> Message:
+        >>>    print("ClientApp query running")
+        >>>    # Create and return an echo reply message
+        >>>    return message.create_reply(content=message.content(), ttl="")
+        """
+
+        def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
+            """Register the query fn with the ClientApp object."""
+            if self._call:
+                raise _registration_error(MessageType.QUERY)
+
+            # Register provided function with the ClientApp object
+            # Wrap mods around the wrapped step function
+            self._query = make_ffn(query_fn, self._mods)
+
+            # Return provided function unmodified
+            return query_fn
+
+        return query_decorator
+
+
+class LoadClientAppError(Exception):
+    """Error when trying to load `ClientApp`."""
+
+
+def _registration_error(fn_name: str) -> ValueError:
+    return ValueError(
+        f"""Use either `@app.{fn_name}()` or `client_fn`, but not both.
+
+        Use the `ClientApp` with an existing `client_fn`:
+
+        >>> class FlowerClient(NumPyClient):
+        >>>     # ...
+        >>>
+        >>> def client_fn(cid) -> Client:
+        >>>    return FlowerClient().to_client()
+        >>>
+        >>> app = ClientApp(
+        >>>    client_fn=client_fn,
+        >>> )
+
+        Use the `ClientApp` with a custom {fn_name} function:
+
+        >>> app = ClientApp()
+        >>>
+        >>> @app.{fn_name}()
+        >>> def {fn_name}(message: Message, context: Context) -> Message:
+        >>>    print("ClientApp {fn_name} running")
+        >>>    # Create and return an echo reply message
+        >>>    return message.create_reply(
+        >>>        content=message.content(), ttl=""
+        >>>    )
+        """,
+    )
diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py
index c39b89b31da3..ab31a289d29b 100644
--- a/src/py/flwr/client/dpfedavg_numpy_client.py
+++ b/src/py/flwr/client/dpfedavg_numpy_client.py
@@ -22,13 +22,20 @@
 from flwr.client.numpy_client import NumPyClient
 from flwr.common.dp import add_gaussian_noise, clip_by_l2
+from flwr.common.logger import warn_deprecated_feature
 from flwr.common.typing import Config, NDArrays, Scalar
 
 
 class DPFedAvgNumPyClient(NumPyClient):
-    """Wrapper for configuring a Flower client for DP."""
+    """Wrapper for configuring a Flower client for DP.
+
+    Warning
+    -------
+    This class is deprecated and will be removed in a future release.
+    """
 
     def __init__(self, client: NumPyClient) -> None:
+        warn_deprecated_feature("`DPFedAvgNumPyClient` wrapper")
         super().__init__()
         self.client = client
diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py
deleted file mode 100644
index 535f096e5866..000000000000
--- a/src/py/flwr/client/flower.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Flower callable."""
-
-
-import importlib
-from typing import List, Optional, cast
-
-from flwr.client.message_handler.message_handler import handle
-from flwr.client.middleware.utils import make_ffn
-from flwr.client.typing import Bwd, ClientFn, Fwd, Layer
-
-
-class Flower:
-    """Flower callable.
-
-    Examples
-    --------
-    Assuming a typical client implementation in `FlowerClient`, you can wrap it in a
-    Flower callable as follows:
-
-    >>> class FlowerClient(NumPyClient):
-    >>>     # ...
-    >>>
-    >>> def client_fn(cid):
-    >>>     return FlowerClient().to_client()
-    >>>
-    >>> flower = Flower(client_fn)
-
-    If the above code is in a Python module called `client`, it can be started as
-    follows:
-
-    >>> flower-client --callable client:flower
-
-    In this `client:flower` example, `client` refers to the Python module in which the
-    previous code lives in. `flower` refers to the global attribute `flower` that points
-    to an object of type `Flower` (a Flower callable).
- """ - - def __init__( - self, - client_fn: ClientFn, # Only for backward compatibility - layers: Optional[List[Layer]] = None, - ) -> None: - # Create wrapper function for `handle` - def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name - task_res, state_updated = handle( - client_fn=client_fn, - state=fwd.state, - task_ins=fwd.task_ins, - ) - return Bwd(task_res=task_res, state=state_updated) - - # Wrap middleware layers around the wrapped handle function - self._call = make_ffn(ffn, layers if layers is not None else []) - - def __call__(self, fwd: Fwd) -> Bwd: - """.""" - return self._call(fwd) - - -class LoadCallableError(Exception): - """.""" - - -def load_flower_callable(module_attribute_str: str) -> Flower: - """Load the `Flower` object specified in a module attribute string. - - The module/attribute string should have the form :. Valid - examples include `client:flower` and `project.package.module:wrapper.flower`. It - must refer to a module on the PYTHONPATH, the module needs to have the specified - attribute, and the attribute must be of type `Flower`. - """ - module_str, _, attributes_str = module_attribute_str.partition(":") - if not module_str: - raise LoadCallableError( - f"Missing module in {module_attribute_str}", - ) from None - if not attributes_str: - raise LoadCallableError( - f"Missing attribute in {module_attribute_str}", - ) from None - - # Load module - try: - module = importlib.import_module(module_str) - except ModuleNotFoundError: - raise LoadCallableError( - f"Unable to load module {module_str}", - ) from None - - # Recursively load attribute - attribute = module - try: - for attribute_str in attributes_str.split("."): - attribute = getattr(attribute, attribute_str) - except AttributeError: - raise LoadCallableError( - f"Unable to load attribute {attributes_str} from module {module_str}", - ) from None - - # Check type - if not isinstance(attribute, Flower): - raise LoadCallableError( - f"Attribute {attributes_str} is not of type {Flower}", - ) from None - - return cast(Flower, attribute) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 5f11912c587c..163a58542c9e 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,15 +20,24 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union - -from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from typing import Callable, Iterator, Optional, Tuple, Union, cast + +from flwr.common import ( + GRPC_MAX_MESSAGE_LENGTH, + ConfigsRecord, + Message, + Metadata, + RecordSet, +) +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.common.retry_invoker import RetryInvoker from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, + Reason, ServerMessage, ) from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611 @@ -46,15 +55,16 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_connection( +def grpc_connection( # pylint: disable=R0915 server_address: str, insecure: bool, + retry_invoker: RetryInvoker, # pylint: 
disable=unused-argument
     max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
     root_certificates: Optional[Union[bytes, str]] = None,
 ) -> Iterator[
     Tuple[
-        Callable[[], Optional[TaskIns]],
-        Callable[[TaskRes], None],
+        Callable[[], Optional[Message]],
+        Callable[[Message], None],
         Optional[Callable[[], None]],
         Optional[Callable[[], None]],
     ]
@@ -67,6 +77,11 @@ def grpc_connection(
         The IPv4 or IPv6 address of the server. If the Flower server runs on the same
         machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
         `"[::]:8080"`.
+    insecure : bool
+        Starts an insecure gRPC connection when True. Enables HTTPS connection
+        when False, using system certificates if `root_certificates` is None.
+    retry_invoker: RetryInvoker
+        Unused argument present for compatibility.
     max_message_length : int
         The maximum length of gRPC messages that can be exchanged with the Flower
         server. The default should be sufficient for most models. Users who train
@@ -117,23 +132,94 @@ def grpc_connection(
     server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None))
 
-    def receive() -> TaskIns:
-        server_message = next(server_message_iterator)
-        return TaskIns(
-            task_id=str(uuid.uuid4()),
-            group_id="",
-            run_id=0,
-            task=Task(
-                producer=Node(node_id=0, anonymous=True),
-                consumer=Node(node_id=0, anonymous=True),
-                ancestry=[],
-                legacy_server_message=server_message,
+    def receive() -> Message:
+        # Receive ServerMessage proto
+        proto = next(server_message_iterator)
+
+        # ServerMessage proto --> *Ins --> RecordSet
+        field = proto.WhichOneof("msg")
+        message_type = ""
+        if field == "get_properties_ins":
+            recordset = compat.getpropertiesins_to_recordset(
+                serde.get_properties_ins_from_proto(proto.get_properties_ins)
+            )
+            message_type = MessageTypeLegacy.GET_PROPERTIES
+        elif field == "get_parameters_ins":
+            recordset = compat.getparametersins_to_recordset(
+                serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
+            )
+            message_type = MessageTypeLegacy.GET_PARAMETERS
+        elif field == "fit_ins":
+            recordset = compat.fitins_to_recordset(
+                serde.fit_ins_from_proto(proto.fit_ins), False
+            )
+            message_type = MessageType.TRAIN
+        elif field == "evaluate_ins":
+            recordset = compat.evaluateins_to_recordset(
+                serde.evaluate_ins_from_proto(proto.evaluate_ins), False
+            )
+            message_type = MessageType.EVALUATE
+        elif field == "reconnect_ins":
+            recordset = RecordSet()
+            recordset.configs_records["config"] = ConfigsRecord(
+                {"seconds": proto.reconnect_ins.seconds}
+            )
+            message_type = "reconnect"
+        else:
+            raise ValueError(
+                "Unsupported instruction in ServerMessage, "
+                "cannot deserialize from ProtoBuf"
+            )
+
+        # Construct Message
+        return Message(
+            metadata=Metadata(
+                run_id=0,
+                message_id=str(uuid.uuid4()),
+                src_node_id=0,
+                dst_node_id=0,
+                reply_to_message="",
+                group_id="",
+                ttl="",
+                message_type=message_type,
             ),
+            content=recordset,
         )
 
-    def send(task_res: TaskRes) -> None:
-        msg = task_res.task.legacy_client_message
-        return queue.put(msg, block=False)
+    def send(message: Message) -> None:
+        # Retrieve RecordSet and message_type
+        recordset = message.content
+        message_type = message.metadata.message_type
+
+        # RecordSet --> *Res --> *Res proto -> ClientMessage proto
+        if message_type == MessageTypeLegacy.GET_PROPERTIES:
+            getpropres = compat.recordset_to_getpropertiesres(recordset)
+            msg_proto = ClientMessage(
+                get_properties_res=serde.get_properties_res_to_proto(getpropres)
+            )
+        elif message_type == MessageTypeLegacy.GET_PARAMETERS:
+            getparamres =
compat.recordset_to_getparametersres(recordset, False) + msg_proto = ClientMessage( + get_parameters_res=serde.get_parameters_res_to_proto(getparamres) + ) + elif message_type == MessageType.TRAIN: + fitres = compat.recordset_to_fitres(recordset, False) + msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres)) + elif message_type == MessageType.EVALUATE: + evalres = compat.recordset_to_evaluateres(recordset) + msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres)) + elif message_type == "reconnect": + reason = cast( + Reason.ValueType, recordset.configs_records["config"]["reason"] + ) + msg_proto = ClientMessage( + disconnect_res=ClientMessage.DisconnectRes(reason=reason) + ) + else: + raise ValueError(f"Invalid message type: {message_type}") + + # Send ClientMessage proto + return queue.put(msg_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index bcfa76bb36c0..b7737f511a2a 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,23 +23,53 @@ import grpc -from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 +from flwr.common import ConfigsRecord, Message, Metadata, RecordSet +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageTypeLegacy +from flwr.common.retry_invoker import RetryInvoker, exponential +from flwr.common.typing import Code, GetPropertiesRes, Status from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, ServerMessage, ) from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_server import start_grpc_server +from flwr.server.superlink.fleet.grpc_bidi.grpc_server import start_grpc_server from .connection import grpc_connection EXPECTED_NUM_SERVER_MESSAGE = 10 -SERVER_MESSAGE = ServerMessage() +SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns()) SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns()) -CLIENT_MESSAGE = ClientMessage() -CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes()) +MESSAGE_GET_PROPERTIES = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + group_id="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + ), + content=compat.getpropertiesres_to_recordset( + GetPropertiesRes(Status(Code.OK, ""), {}) + ), +) +MESSAGE_DISCONNECT = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + group_id="", + ttl="", + message_type="reconnect", + ), + content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), +) def unused_tcp_port() -> int: @@ -75,7 +105,9 @@ def mock_join( # type: ignore # pylint: disable=invalid-name @patch( - "flwr.server.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join", + # pylint: disable=line-too-long + "flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join", # noqa: E501 + # pylint: enable=line-too-long mock_join, ) def test_integration_connection() -> None: @@ -96,40 +128,30 @@ def test_integration_connection() -> None: def run_client() -> int: messages_received: int = 0 - with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn: + with grpc_connection( + server_address=f"[::]:{port}", + 
insecure=True, + retry_invoker=RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=1, + max_time=None, + ), + ) as conn: receive, send, _, _ = conn # Setup processing loop while True: # Block until server responds with a message - task_ins = receive() - - if task_ins is None: - raise ValueError("Unexpected None value") - - # pylint: disable=no-member - if task_ins.HasField("task") and task_ins.task.HasField( - "legacy_server_message" - ): - server_message = task_ins.task.legacy_server_message - else: - server_message = None - # pylint: enable=no-member - - if server_message is None: - raise ValueError("Unexpected None value") + message = receive() messages_received += 1 - if server_message.HasField("reconnect_ins"): - task_res = TaskRes( - task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT) - ) - send(task_res) + if message.metadata.message_type == "reconnect": # type: ignore + send(MESSAGE_DISCONNECT) break # Process server_message and send client_message... - task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE)) - send(task_res) + send(MESSAGE_GET_PROPERTIES) return messages_received diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index cb1a7021dc9d..e6e22998b947 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -16,19 +16,19 @@ from contextlib import contextmanager +from copy import copy from logging import DEBUG, ERROR from pathlib import Path from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast -from flwr.client.message_handler.task_handler import ( - configure_task_res, - get_task_ins, - validate_task_ins, - validate_task_res, -) +from flwr.client.message_handler.message_handler import validate_out_message +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature +from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker +from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -37,10 +37,10 @@ ) from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 KEY_NODE = "node" -KEY_TASK_INS = "current_task_ins" +KEY_METADATA = "in_message_metadata" def on_channel_state_change(channel_connectivity: str) -> None: @@ -52,12 +52,13 @@ def on_channel_state_change(channel_connectivity: str) -> None: def grpc_request_response( server_address: str, insecure: bool, + retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] @@ -73,6 +74,13 @@ def grpc_request_response( The IPv6 address of the server with `http://` or `https://`. 
If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after gRPC errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. root_certificates : Optional[Union[bytes, str]] (default: None) @@ -101,8 +109,8 @@ def grpc_request_response( channel.subscribe(on_channel_state_change) stub = FleetStub(channel) - # Necessary state to link TaskRes to TaskIns - state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} + # Necessary state to validate messages to be sent + state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} # Enable create_node and delete_node to store node node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} @@ -114,7 +122,8 @@ def grpc_request_response( def create_node() -> None: """Set create_node.""" create_node_request = CreateNodeRequest() - create_node_response = stub.CreateNode( + create_node_response = retry_invoker.invoke( + stub.CreateNode, request=create_node_request, ) node_store[KEY_NODE] = create_node_response.node @@ -128,11 +137,11 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_request = DeleteNodeRequest(node=node) - stub.DeleteNode(request=delete_node_request) + retry_invoker.invoke(stub.DeleteNode, request=delete_node_request) del node_store[KEY_NODE] - def receive() -> Optional[TaskIns]: + def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node if node_store[KEY_NODE] is None: @@ -142,50 +151,53 @@ def receive() -> Optional[TaskIns]: # Request instructions (task) from server request = PullTaskInsRequest(node=node) - response = stub.PullTaskIns(request=request) + response = retry_invoker.invoke(stub.PullTaskIns, request=request) # Get the current TaskIns task_ins: Optional[TaskIns] = get_task_ins(response) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True + if task_ins is not None and not ( + task_ins.task.consumer.node_id == node.node_id + and validate_task_ins(task_ins) ): task_ins = None - # Remember `task_ins` until `task_res` is available - state[KEY_TASK_INS] = task_ins + # Construct the Message + in_message = message_from_taskins(task_ins) if task_ins else None + + # Remember `metadata` of the in message + state[KEY_METADATA] = copy(in_message.metadata) if in_message else None - # Return the TaskIns if available - return task_ins + # Return the message if available + return in_message - def send(task_res: TaskRes) -> None: + def send(message: Message) -> None: """Send task result back to server.""" # Get Node if node_store[KEY_NODE] is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) - # Get incoming TaskIns - if state[KEY_TASK_INS] is None: - log(ERROR, "No current TaskIns") + # Get incoming message + in_metadata = state[KEY_METADATA] + if in_metadata is None: + log(ERROR, "No current message") return - task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS]) - # Check if fields to be set are not initialized - if not validate_task_res(task_res): - state[KEY_TASK_INS] = None - log(ERROR, "TaskRes has been initialized 
accidentally") + # Validate out message + if not validate_out_message(message, in_metadata): + log(ERROR, "Invalid out message") + return - # Configure TaskRes - task_res = configure_task_res(task_res, task_ins, node) + # Construct TaskRes + task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes request = PushTaskResRequest(task_res_list=[task_res]) - _ = stub.PushTaskRes(request) + _ = retry_invoker.invoke(stub.PushTaskRes, request) - state[KEY_TASK_INS] = None + state[KEY_METADATA] = None try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 8cfe909c1738..9a5d70b1ac4d 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,28 +15,28 @@ """Client-side message handler.""" -from typing import Optional, Tuple +from logging import WARN +from typing import Optional, Tuple, cast from flwr.client.client import ( - Client, maybe_call_evaluate, maybe_call_fit, maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - wrap_client_message_in_task_res, -) -from flwr.client.run_state import RunState -from flwr.client.secure_aggregation import SecureAggregationHandler +from flwr.client.numpy_client import NumPyClient from flwr.client.typing import ClientFn -from flwr.common import serde -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log +from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.recordset_compat import ( + evaluateres_to_recordset, + fitres_to_recordset, + getparametersres_to_recordset, + getpropertiesres_to_recordset, + recordset_to_evaluateins, + recordset_to_fitins, + recordset_to_getparametersins, + recordset_to_getpropertiesins, ) from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -53,128 +53,97 @@ class UnknownServerMessage(Exception): """Exception indicating that the received message is unknown.""" -def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: +def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: """Handle control part of the incoming message. Parameters ---------- - task_ins : TaskIns - The task instruction coming from the server, to be processed by the client. + message : Message + The Message coming from the server, to be processed by the client. Returns ------- + message : Optional[Message] + Message to be sent back to the server. If None, the client should + continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server. - keep_going : bool - Flag that indicates whether the client should continue to process the - next message from the server (True) or disconnect and optionally - reconnect later (False). 
""" - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - # SecAgg message - if server_msg is None: - return None, 0 - - # ReconnectIns message - field = server_msg.WhichOneof("msg") - if field == "reconnect_ins": - disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins) - task_res = wrap_client_message_in_task_res(disconnect_msg) - return task_res, sleep_duration + if message.metadata.message_type == "reconnect": + # Retrieve ReconnectIns from recordset + recordset = message.content + seconds = cast(int, recordset.configs_records["config"]["seconds"]) + # Construct ReconnectIns and call _reconnect + disconnect_msg, sleep_duration = _reconnect( + ServerMessage.ReconnectIns(seconds=seconds) + ) + # Store DisconnectRes in recordset + reason = cast(int, disconnect_msg.disconnect_res.reason) + recordset = RecordSet() + recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) + out_message = message.create_reply(recordset, ttl="") + # Return TaskRes and sleep duration + return out_message, sleep_duration # Any other message return None, 0 -def handle( - client_fn: ClientFn, state: RunState, task_ins: TaskIns -) -> Tuple[TaskRes, RunState]: - """Handle incoming TaskIns from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - state : RunState - A dataclass storing the state for the run being executed by the client. - task_ins: TaskIns - The task instruction coming from the server, to be processed by the client. - - Returns - ------- - task_res : TaskRes - The task response that should be returned to the server. - """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - if server_msg is None: - # Instantiate the client - client = client_fn("-1") - client.set_state(state) - # Secure Aggregation - if task_ins.task.HasField("sa") and isinstance( - client, SecureAggregationHandler - ): - # pylint: disable-next=invalid-name - named_values = serde.named_values_from_proto(task_ins.task.sa.named_values) - res = client.handle_secure_aggregation(named_values) - task_res = TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task( - ancestry=[], - sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), - ), - ) - return task_res, client.get_state() - raise NotImplementedError() - client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) - task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_state - - -def handle_legacy_message( - client_fn: ClientFn, state: RunState, server_msg: ServerMessage -) -> Tuple[ClientMessage, RunState]: - """Handle incoming messages from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - state : RunState - A dataclass storing the state for the run being executed by the client. - server_msg: ServerMessage - The message coming from the server, to be processed by the client. - - Returns - ------- - client_msg : ClientMessage - The result message that should be returned to the server. 
- """ - field = server_msg.WhichOneof("msg") - - # Must be handled elsewhere - if field == "reconnect_ins": - raise UnexpectedServerMessage() - - # Instantiate the client - client = client_fn("-1") - client.set_state(state) - # Execute task - message = None - if field == "get_properties_ins": - message = _get_properties(client, server_msg.get_properties_ins) - if field == "get_parameters_ins": - message = _get_parameters(client, server_msg.get_parameters_ins) - if field == "fit_ins": - message = _fit(client, server_msg.fit_ins) - if field == "evaluate_ins": - message = _evaluate(client, server_msg.evaluate_ins) - if message: - return message, client.get_state() - raise UnknownServerMessage() +def handle_legacy_message_from_msgtype( + client_fn: ClientFn, message: Message, context: Context +) -> Message: + """Handle legacy message in the inner most mod.""" + client = client_fn(str(message.metadata.partition_id)) + + # Check if NumPyClient is returend + if isinstance(client, NumPyClient): + client = client.to_client() + log( + WARN, + "Deprecation Warning: The `client_fn` function must return an instance " + "of `Client`, but an instance of `NumpyClient` was returned. " + "Please use `NumPyClient.to_client()` method to convert it to `Client`.", + ) + + client.set_context(context) + + message_type = message.metadata.message_type + + # Handle GetPropertiesIns + if message_type == MessageTypeLegacy.GET_PROPERTIES: + get_properties_res = maybe_call_get_properties( + client=client, + get_properties_ins=recordset_to_getpropertiesins(message.content), + ) + out_recordset = getpropertiesres_to_recordset(get_properties_res) + # Handle GetParametersIns + elif message_type == MessageTypeLegacy.GET_PARAMETERS: + get_parameters_res = maybe_call_get_parameters( + client=client, + get_parameters_ins=recordset_to_getparametersins(message.content), + ) + out_recordset = getparametersres_to_recordset( + get_parameters_res, keep_input=False + ) + # Handle FitIns + elif message_type == MessageType.TRAIN: + fit_res = maybe_call_fit( + client=client, + fit_ins=recordset_to_fitins(message.content, keep_input=True), + ) + out_recordset = fitres_to_recordset(fit_res, keep_input=False) + # Handle EvaluateIns + elif message_type == MessageType.EVALUATE: + evaluate_res = maybe_call_evaluate( + client=client, + evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True), + ) + out_recordset = evaluateres_to_recordset(evaluate_res) + else: + raise ValueError(f"Invalid message type: {message_type}") + + # Return Message + return message.create_reply(out_recordset, ttl="") def _reconnect( @@ -191,65 +160,18 @@ def _reconnect( return ClientMessage(disconnect_res=disconnect_res), sleep_duration -def _get_properties( - client: Client, get_properties_msg: ServerMessage.GetPropertiesIns -) -> ClientMessage: - # Deserialize `get_properties` instruction - get_properties_ins = serde.get_properties_ins_from_proto(get_properties_msg) - - # Request properties - get_properties_res = maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - # Serialize response - get_properties_res_proto = serde.get_properties_res_to_proto(get_properties_res) - return ClientMessage(get_properties_res=get_properties_res_proto) - - -def _get_parameters( - client: Client, get_parameters_msg: ServerMessage.GetParametersIns -) -> ClientMessage: - # Deserialize `get_parameters` instruction - get_parameters_ins = serde.get_parameters_ins_from_proto(get_parameters_msg) - - # Request parameters - get_parameters_res 
= maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - # Serialize response - get_parameters_res_proto = serde.get_parameters_res_to_proto(get_parameters_res) - return ClientMessage(get_parameters_res=get_parameters_res_proto) - - -def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage: - # Deserialize fit instruction - fit_ins = serde.fit_ins_from_proto(fit_msg) - - # Perform fit - fit_res = maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - # Serialize fit result - fit_res_proto = serde.fit_res_to_proto(fit_res) - return ClientMessage(fit_res=fit_res_proto) - - -def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage: - # Deserialize evaluate instruction - evaluate_ins = serde.evaluate_ins_from_proto(evaluate_msg) - - # Perform evaluation - evaluate_res = maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - # Serialize evaluate result - evaluate_res_proto = serde.evaluate_res_to_proto(evaluate_res) - return ClientMessage(evaluate_res=evaluate_res_proto) +def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> bool: + """Validate the out message.""" + out_meta = out_message.metadata + in_meta = in_message_metadata + if ( # pylint: disable-next=too-many-boolean-expressions + out_meta.run_id == in_meta.run_id + and out_meta.message_id == "" # This will be generated by the server + and out_meta.src_node_id == in_meta.dst_node_id + and out_meta.dst_node_id == in_meta.src_node_id + and out_meta.reply_to_message == in_meta.message_id + and out_meta.group_id == in_meta.group_id + and out_meta.message_type == in_meta.message_type + ): + return True + return False diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 194f75fe30ca..eaf16f7dc993 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -15,12 +15,16 @@ """Client-side message handler tests.""" +import unittest import uuid +from copy import copy +from typing import List from flwr.client import Client -from flwr.client.run_state import RunState from flwr.client.typing import ClientFn from flwr.common import ( + Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -29,20 +33,17 @@ GetParametersRes, GetPropertiesIns, GetPropertiesRes, + Message, + Metadata, Parameters, - serde, - typing, -) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - Code, - ServerMessage, + RecordSet, Status, ) +from flwr.common import recordset_compat as compat +from flwr.common import typing +from flwr.common.constant import MessageTypeLegacy -from .message_handler import handle, handle_control_message +from .message_handler import handle_legacy_message_from_msgtype, validate_out_message class ClientWithoutProps(Client): @@ -121,137 +122,169 @@ def test_client_without_get_properties() -> None: """Test client implementing get_properties.""" # Prepare client = ClientWithoutProps() - ins = ServerMessage.GetPropertiesIns() - - task_ins: TaskIns = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), 
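Note (editor's illustration, not part of the patch): `validate_out_message`, defined above, accepts a reply only if its metadata mirrors the incoming metadata, with source and destination swapped and `reply_to_message` pointing at the incoming `message_id`. A sketch, using the same field values as the `TestMessageValidation` setup further below:

# Illustrative sketch only: the metadata mirroring enforced by
# validate_out_message.
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.common import Message, Metadata, RecordSet

in_meta = Metadata(
    run_id=123, message_id="qwerty", src_node_id=10, dst_node_id=20,
    reply_to_message="", group_id="group1", ttl="60", message_type="mock",
)
out_meta = Metadata(
    run_id=123, message_id="",       # message_id is generated by the server
    src_node_id=20, dst_node_id=10,  # source and destination swapped
    reply_to_message="qwerty",       # points at the incoming message_id
    group_id="group1", ttl="60", message_type="mock",
)
assert validate_out_message(Message(metadata=out_meta, content=RecordSet()), in_meta)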
+ recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), - state=RunState(state={}), - task_ins=task_ins, - ) - - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) + message=message, + context=Context(state=RecordSet()), ) - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.GET_PROPERTIES_NOT_IMPLEMENTED, message="Client does not implement `get_properties`", - ) + ), + properties={}, + ) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + reply_to_message=message.metadata.message_id, + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + ), + content=expected_rs, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) - assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata def test_client_with_get_properties() -> None: """Test client not implementing get_properties.""" # Prepare client = ClientWithProps() - ins = ServerMessage.GetPropertiesIns() - task_ins = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), - state=RunState(state={}), - task_ins=task_ins, - ) - - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - 
run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) + message=message, + context=Context(state=RecordSet()), ) - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.OK, message="Success", ), - properties=serde.properties_to_proto( - properties={"str_prop": "val", "int_prop": 1} + properties={"str_prop": "val", "int_prop": 1}, + ) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + reply_to_message=message.metadata.message_id, + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=expected_rs, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) - assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata + + +class TestMessageValidation(unittest.TestCase): + """Test message validation.""" + + def setUp(self) -> None: + """Set up the message validation.""" + # Common setup for tests + self.in_metadata = Metadata( + run_id=123, + message_id="qwerty", + src_node_id=10, + dst_node_id=20, + reply_to_message="", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.valid_out_metadata = Metadata( + run_id=123, + message_id="", + src_node_id=20, + dst_node_id=10, + reply_to_message="qwerty", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.common_content = RecordSet() + + def test_valid_message(self) -> None: + """Test a valid message.""" + # Prepare + valid_message = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Assert + self.assertTrue(validate_out_message(valid_message, self.in_metadata)) + + def test_invalid_message_run_id(self) -> None: + """Test invalid messages.""" + # Prepare + msg = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Execute + invalid_metadata_list: List[Metadata] = [] + attrs = list(vars(self.valid_out_metadata).keys()) + for attr in attrs: + if attr == "_partition_id": + continue + if attr == "_ttl": # Skip configurable ttl + continue + # Make an invalid metadata + invalid_metadata = copy(self.valid_out_metadata) + value = getattr(invalid_metadata, attr) + if isinstance(value, int): + value = 999 + elif isinstance(value, str): + value = "999" + setattr(invalid_metadata, attr, value) + # Add to list + invalid_metadata_list.append(invalid_metadata) + + # Assert + for invalid_metadata in invalid_metadata_list: + msg._metadata = invalid_metadata # pylint: disable=protected-access + self.assertFalse(validate_out_message(msg, self.in_metadata)) diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 667cb9c98d46..7f515a30fe5a 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -18,23 +18,16 @@ from typing import Optional from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 
import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: +def validate_task_ins(task_ins: TaskIns) -> bool: """Validate a TaskIns before it entering the message handling process. Parameters ---------- task_ins: TaskIns The task instruction coming from the server. - discard_reconnect_ins: bool - If True, ReconnectIns will not be considered as valid content. Returns ------- @@ -42,57 +35,8 @@ def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: True if the TaskIns is deemed valid and therefore suitable for undergoing the message handling process, False otherwise. """ - # Check if the task_ins contains legacy_server_message or sa. - # If legacy_server_message is set, check if ServerMessage is one of - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns, ReconnectIns*} - # Discard ReconnectIns if discard_reconnect_ins is true. - if ( - not task_ins.HasField("task") - or ( - not task_ins.task.HasField("legacy_server_message") - and not task_ins.task.HasField("sa") - ) - or ( - discard_reconnect_ins - and task_ins.task.legacy_server_message.WhichOneof("msg") == "reconnect_ins" - ) - ): + if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")): return False - - return True - - -def validate_task_res(task_res: TaskRes) -> bool: - """Validate a TaskRes before filling its fields in the `send()` function. - - Parameters - ---------- - task_res: TaskRes - The task response to be sent to the server. - - Returns - ------- - is_valid: bool - True if the `task_id`, `group_id`, and `run_id` fields in TaskRes - and the `producer`, `consumer`, and `ancestry` fields in its sub-message Task - are not initialized accidentally elsewhere, - False otherwise. - """ - # Retrieve initialized fields in TaskRes and Task - initialized_fields_in_task_res = {field.name for field, _ in task_res.ListFields()} - initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} - - # Check if certain fields are already initialized - if ( # pylint: disable-next=too-many-boolean-expressions - "task_id" in initialized_fields_in_task_res - or "group_id" in initialized_fields_in_task_res - or "run_id" in initialized_fields_in_task_res - or "producer" in initialized_fields_in_task - or "consumer" in initialized_fields_in_task - or "ancestry" in initialized_fields_in_task - ): - return False - return True @@ -108,61 +52,3 @@ def get_task_ins( task_ins: TaskIns = pull_task_ins_response.task_ins_list[0] return task_ins - - -def get_server_message_from_task_ins( - task_ins: TaskIns, exclude_reconnect_ins: bool -) -> Optional[ServerMessage]: - """Get ServerMessage from TaskIns, if available.""" - # Return the message if it is in - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns} - # Return the message if it is ReconnectIns and exclude_reconnect_ins is False. - if not validate_task_ins( - task_ins, discard_reconnect_ins=exclude_reconnect_ins - ) or not task_ins.task.HasField("legacy_server_message"): - return None - - return task_ins.task.legacy_server_message - - -def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: - """Wrap ClientMessage in TaskRes.""" - # Instantiate a TaskRes, only filling client_message field. 
- return TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task(ancestry=[], legacy_client_message=client_message), - ) - - -def configure_task_res( - task_res: TaskRes, ref_task_ins: TaskIns, producer: Node -) -> TaskRes: - """Set the metadata of a TaskRes. - - Fill `group_id` and `run_id` in TaskRes - and `producer`, `consumer`, and `ancestry` in Task in TaskRes. - - `producer` in Task in TaskRes will remain unchanged/unset. - - Note that protobuf API `protobuf.message.MergeFrom(other_msg)` - does NOT always overwrite fields that are set in `other_msg`. - Please refer to: - https://googleapis.dev/python/protobuf/latest/google/protobuf/message.html - """ - task_res = TaskRes( - task_id="", # This will be generated by the server - group_id=ref_task_ins.group_id, - run_id=ref_task_ins.run_id, - task=task_res.task, - ) - # pylint: disable-next=no-member - task_res.task.MergeFrom( - Task( - producer=producer, - consumer=ref_task_ins.task.producer, - ancestry=[ref_task_ins.task_id], - ) - ) - return task_res diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index c1111d0935c0..c8b9e14737ff 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -15,108 +15,31 @@ """Tests for module task_handler.""" -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - get_task_ins, - validate_task_ins, - validate_task_res, - wrap_client_message_in_task_res, -) +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins +from flwr.common import RecordSet, serde from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 def test_validate_task_ins_no_task() -> None: """Test validate_task_ins.""" task_ins = TaskIns(task=None) - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) def test_validate_task_ins_no_content() -> None: """Test validate_task_ins.""" - task_ins = TaskIns(task=Task(legacy_server_message=None, sa=None)) + task_ins = TaskIns(task=Task(recordset=None)) - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) -def test_validate_task_ins_with_reconnect_ins() -> None: +def test_validate_task_ins_valid() -> None: """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - ) - ) - - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_valid_legacy_server_message() -> None: - """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - ) - ) - - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_valid_sa() -> None: 
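Note (editor's illustration, not part of the patch): with `validate_task_res` gone, validity checking reduces to a single rule, which the updated tests below exercise: a `TaskIns` is valid exactly when it wraps a `recordset`.

# Illustrative sketch only: the simplified validity rule for TaskIns.
from flwr.client.message_handler.task_handler import validate_task_ins
from flwr.common import RecordSet, serde
from flwr.proto.task_pb2 import Task, TaskIns

assert not validate_task_ins(TaskIns(task=None))                  # no task
assert not validate_task_ins(TaskIns(task=Task(recordset=None)))  # no content
assert validate_task_ins(
    TaskIns(task=Task(recordset=serde.recordset_to_proto(RecordSet())))
)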
- """Test validate_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) - - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_res() -> None: - """Test validate_task_res.""" - task_res = TaskRes(task=Task()) - assert validate_task_res(task_res) - - task_res.task_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.group_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.run_id = 61016 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.producer.node_id = 0 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.consumer.node_id = 0 - assert not validate_task_res(task_res) + task_ins = TaskIns(task=Task(recordset=serde.recordset_to_proto(RecordSet()))) - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.ancestry.append("123") - assert not validate_task_res(task_res) + assert validate_task_ins(task_ins) def test_get_task_ins_empty_response() -> None: @@ -142,61 +65,3 @@ def test_get_task_ins_multiple_ins() -> None: ) actual_task_ins = get_task_ins(res) assert actual_task_ins == expected_task_ins - - -def test_get_server_message_from_task_ins_invalid() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(legacy_server_message=None)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_reconnect_ins() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f == expected_server_message - - -def test_get_server_message_from_task_ins_sa() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_valid_legacy_server_message() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t == expected_server_message - assert msg_f == expected_server_message - - -def test_wrap_client_message_in_task_res() -> None: - """Test wrap_client_message_in_task_res.""" - expected_client_message = ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes() - ) - task_res = wrap_client_message_in_task_res(expected_client_message) - - assert validate_task_res(task_res) - # pylint: disable-next=no-member - assert 
task_res.task.legacy_client_message == expected_client_message diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py deleted file mode 100644 index 006fe6db4799..000000000000 --- a/src/py/flwr/client/middleware/utils_test.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the utility functions.""" - - -import unittest -from typing import List - -from flwr.client.run_state import RunState -from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 - -from .utils import make_ffn - - -def make_mock_middleware(name: str, footprint: List[str]) -> Layer: - """Make a mock middleware layer.""" - - def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: - footprint.append(name) - fwd.task_ins.task_id += f"{name}" - bwd = app(fwd) - footprint.append(name) - bwd.task_res.task_id += f"{name}" - return bwd - - return middleware - - -def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: - """Make a mock app.""" - - def app(fwd: Fwd) -> Bwd: - footprint.append(name) - fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=RunState({})) - - return app - - -class TestMakeApp(unittest.TestCase): - """Tests for the `make_app` function.""" - - def test_multiple_middlewares(self) -> None: - """Test if multiple middlewares are called in the correct order.""" - # Prepare - footprint: List[str] = [] - mock_app = make_mock_app("app", footprint) - mock_middleware_names = [f"middleware{i}" for i in range(1, 15)] - mock_middleware_layers = [ - make_mock_middleware(name, footprint) for name in mock_middleware_names - ] - task_ins = TaskIns() - - # Execute - wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res - - # Assert - trace = mock_middleware_names + ["app"] - self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) - # pylint: disable-next=no-member - self.assertEqual(task_ins.task_id, "".join(trace)) - self.assertEqual(task_res.task_id, "".join(reversed(trace))) - - def test_filter(self) -> None: - """Test if a middleware can filter incoming TaskIns.""" - # Prepare - footprint: List[str] = [] - mock_app = make_mock_app("app", footprint) - task_ins = TaskIns() - - def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: - footprint.append("filter") - fwd.task_ins.task_id += "filter" - # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) - - # Execute - wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res - - # Assert - self.assertEqual(footprint, ["filter"]) - # pylint: disable-next=no-member - self.assertEqual(task_ins.task_id, "filter") - 
self.assertEqual(task_res.task_id, "filter") diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py new file mode 100644 index 000000000000..69a7d76ce95f --- /dev/null +++ b/src/py/flwr/client/mod/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mods.""" + + +from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod +from .localdp_mod import LocalDpMod +from .secure_aggregation import secagg_mod, secaggplus_mod +from .utils import make_ffn + +__all__ = [ + "adaptiveclipping_mod", + "fixedclipping_mod", + "LocalDpMod", + "make_ffn", + "secagg_mod", + "secaggplus_mod", +] diff --git a/src/py/flwr/client/mod/centraldp_mods.py b/src/py/flwr/client/mod/centraldp_mods.py new file mode 100644 index 000000000000..0c0134e0f876 --- /dev/null +++ b/src/py/flwr/client/mod/centraldp_mods.py @@ -0,0 +1,147 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Clipping modifiers for central DP with client-side clipping.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.context import Context +from flwr.common.differential_privacy import ( + compute_adaptive_clip_model_update, + compute_clip_model_update, +) +from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT +from flwr.common.message import Message + + +def fixedclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side fixed clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It operates on messages of type `MessageType.TRAIN`. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, fixedclipping_mod should be the last to operate on params. 
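Note (editor's illustration, not part of the patch): following the ordering advice above, a registration sketch; `client_fn` here is a stand-in for an existing client factory:

# Illustrative sketch only: register fixedclipping_mod as the last mod so it
# is the final one to touch the parameters before they leave the client.
import flwr as fl
from flwr.client import Client
from flwr.client.mod import fixedclipping_mod

def client_fn(cid: str) -> Client:
    raise NotImplementedError  # stand-in; a real app returns its Client here

app = fl.client.ClientApp(
    client_fn=client_fn,
    mods=[
        # ...any other mods first...
        fixedclipping_mod,
    ],
)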
+ """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." + ) + + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg + + +def adaptiveclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side adaptive clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It also sends KEY_NORM_BIT to the server for computing the new clipping value. + + It operates on messages of type `MessageType.TRAIN`. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, adaptiveclipping_mod should be the last to operate on params. + """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." + ) + if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float): + raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.") + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + norm_bit = compute_adaptive_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + + fit_res.metrics[KEY_NORM_BIT] = norm_bit + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/client/mod/localdp_mod.py b/src/py/flwr/client/mod/localdp_mod.py new file mode 100644 index 000000000000..5f62c9e44800 --- /dev/null +++ b/src/py/flwr/client/mod/localdp_mod.py @@ -0,0 +1,134 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Local DP modifier.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.context import Context +from flwr.common.differential_privacy import ( + add_localdp_gaussian_noise_to_params, + compute_clip_model_update, +) +from flwr.common.message import Message + + +class LocalDpMod: + """Modifier for local differential privacy. + + This mod clips the client model updates and + adds noise to the params before sending them to the server. + + It operates on messages of type `MessageType.TRAIN`. + + Parameters + ---------- + clipping_norm : float + The value of the clipping norm. + sensitivity : float + The sensitivity of the client model. + epsilon : float + The privacy budget. + Smaller value of epsilon indicates a higher level of privacy protection. + delta : float + The failure probability. + The probability that the privacy mechanism + fails to provide the desired level of privacy. + A smaller value of delta indicates a stricter privacy guarantee. + + Examples + -------- + Create an instance of the local DP mod and add it to the client-side mods: + + >>> local_dp_mod = LocalDpMod( ... ) + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[local_dp_mod] + >>> ) + """ + + def __init__( + self, clipping_norm: float, sensitivity: float, epsilon: float, delta: float + ) -> None: + if clipping_norm <= 0: + raise ValueError("The clipping norm should be a positive value.") + + if sensitivity < 0: + raise ValueError("The sensitivity should be a non-negative value.") + + if epsilon < 0: + raise ValueError("Epsilon should be a non-negative value.") + + if delta < 0: + raise ValueError("Delta should be a non-negative value.") + + self.clipping_norm = clipping_norm + self.sensitivity = sensitivity + self.epsilon = epsilon + self.delta = delta + + def __call__( + self, msg: Message, ctxt: Context, call_next: ClientAppCallable + ) -> Message: + """Perform local DP on the client model parameters. + + Parameters + ---------- + msg : Message + The message received from the server. + ctxt : Context + The context of the client. + call_next : ClientAppCallable + The callable to call the next middleware in the chain. + + Returns + ------- + Message + The modified message to be sent back to the server. 
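Note (editor's illustration, not part of the patch): for concreteness, instantiating the mod with hypothetical values; the constructor above rejects a non-positive `clipping_norm` and negative `sensitivity`, `epsilon`, or `delta`:

# Illustrative sketch only: all values below are placeholders.
from flwr.client.mod import LocalDpMod

local_dp_mod = LocalDpMod(
    clipping_norm=1.0,  # must be > 0
    sensitivity=1.0,    # must be >= 0
    epsilon=1.0,        # must be >= 0
    delta=1e-6,         # must be >= 0
)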
+ """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + self.clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + + # Add noise to model params + add_localdp_gaussian_noise_to_params( + fit_res.parameters, self.sensitivity, self.epsilon, self.delta + ) + + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/client/middleware/__init__.py b/src/py/flwr/client/mod/secure_aggregation/__init__.py similarity index 82% rename from src/py/flwr/client/middleware/__init__.py rename to src/py/flwr/client/mod/secure_aggregation/__init__.py index 58b31296fbbe..8892d8c03935 100644 --- a/src/py/flwr/client/middleware/__init__.py +++ b/src/py/flwr/client/mod/secure_aggregation/__init__.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Middleware layers.""" +"""Secure Aggregation mods.""" -from .utils import make_ffn +from .secagg_mod import secagg_mod +from .secaggplus_mod import secaggplus_mod __all__ = [ - "make_ffn", + "secagg_mod", + "secaggplus_mod", ] diff --git a/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py new file mode 100644 index 000000000000..d87af59a4e6e --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Modifier for the SecAgg protocol.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import Context, Message + +from .secaggplus_mod import secaggplus_mod + + +def secagg_mod( + msg: Message, + ctxt: Context, + call_next: ClientAppCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg protocol.""" + return secaggplus_mod(msg, ctxt, call_next) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py new file mode 100644 index 000000000000..ed0f8f4fd7b5 --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -0,0 +1,527 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Message handler for the SecAgg+ protocol.""" + + +import os +from dataclasses import dataclass, field +from logging import INFO, WARNING +from typing import Any, Callable, Dict, List, Tuple, cast + +from flwr.client.typing import ClientAppCallable +from flwr.common import ( + ConfigsRecord, + Context, + Message, + RecordSet, + ndarray_to_bytes, + parameters_to_ndarrays, +) +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.shamir import create_shares +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_private_key, + bytes_to_public_key, + decrypt, + encrypt, + generate_key_pairs, + generate_shared_key, + private_key_to_bytes, + public_key_to_bytes, +) +from flwr.common.secure_aggregation.ndarrays_arithmetic import ( + factor_combine, + parameters_addition, + parameters_mod, + parameters_multiply, + parameters_subtraction, +) +from flwr.common.secure_aggregation.quantization import quantize +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, + Key, + Stage, +) +from flwr.common.secure_aggregation.secaggplus_utils import ( + pseudo_rand_gen, + share_keys_plaintext_concat, + share_keys_plaintext_separate, +) +from flwr.common.typing import ConfigsRecordValues, FitRes + + +@dataclass +# pylint: disable-next=too-many-instance-attributes +class SecAggPlusState: + """State of the SecAgg+ protocol.""" + + current_stage: str = Stage.UNMASK + + nid: int = 0 + sample_num: int = 0 + share_num: int = 0 + threshold: int = 0 + clipping_range: float = 0.0 + target_range: int = 0 + mod_range: int = 0 + max_weight: float = 0.0 + + # Secret key (sk) and public key (pk) + sk1: bytes = b"" + pk1: bytes = b"" + sk2: bytes = b"" + pk2: bytes = b"" + + # Random seed for generating the private mask + rd_seed: bytes = b"" + + rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) + sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) + # The dict of the shared secrets from sk2 + ss2_dict: Dict[int, bytes] = field(default_factory=dict) + public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) + + def __init__(self, **kwargs: ConfigsRecordValues) -> None: + for k, v in kwargs.items(): + if k.endswith(":V"): + continue + new_v: Any = v + if k.endswith(":K"): + k = k[:-2] + keys = cast(List[int], v) + values = cast(List[bytes], kwargs[f"{k}:V"]) + if len(values) > len(keys): + updated_values = [ + tuple(values[i : i + 2]) for i in range(0, len(values), 2) + ] + new_v = dict(zip(keys, updated_values)) + else: + new_v = dict(zip(keys, values)) + self.__setattr__(k, new_v) + + def to_dict(self) -> Dict[str, ConfigsRecordValues]: + """Convert the state to a dictionary.""" + ret = vars(self) + for k in list(ret.keys()): + if 
isinstance(ret[k], dict): + # Replace dict with two lists + v = cast(Dict[str, Any], ret.pop(k)) + ret[f"{k}:K"] = list(v.keys()) + if k == "public_keys_dict": + v_list: List[bytes] = [] + for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list.extend(b1_b2) + ret[f"{k}:V"] = v_list + else: + ret[f"{k}:V"] = list(v.values()) + return ret + + +def _get_fit_fn( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Callable[[], FitRes]: + """Get the fit function.""" + + def fit() -> FitRes: + out_msg = call_next(msg, ctxt) + return compat.recordset_to_fitres(out_msg.content, keep_input=False) + + return fit + + +def secaggplus_mod( + msg: Message, + ctxt: Context, + call_next: ClientAppCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg+ protocol.""" + # Ignore non-fit messages + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + # Retrieve local state + if RECORD_KEY_STATE not in ctxt.state.configs_records: + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({}) + state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] + state = SecAggPlusState(**state_dict) + + # Retrieve incoming configs + configs = msg.content.configs_records[RECORD_KEY_CONFIGS] + + # Check the validity of the next stage + check_stage(state.current_stage, configs) + + # Update the current stage + state.current_stage = cast(str, configs.pop(Key.STAGE)) + + # Check the validity of the configs based on the current stage + check_configs(state.current_stage, configs) + + # Execute + if state.current_stage == Stage.SETUP: + state.nid = msg.metadata.dst_node_id + res = _setup(state, configs) + elif state.current_stage == Stage.SHARE_KEYS: + res = _share_keys(state, configs) + elif state.current_stage == Stage.COLLECT_MASKED_VECTORS: + fit = _get_fit_fn(msg, ctxt, call_next) + res = _collect_masked_vectors(state, configs, fit) + elif state.current_stage == Stage.UNMASK: + res = _unmask(state, configs) + else: + raise ValueError(f"Unknown secagg stage: {state.current_stage}") + + # Save state + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) + + # Return message + content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}) + return msg.create_reply(content, ttl="") + + +def check_stage(current_stage: str, configs: ConfigsRecord) -> None: + """Check the validity of the next stage.""" + # Check the existence of Config.STAGE + if Key.STAGE not in configs: + raise KeyError( + f"The required key '{Key.STAGE}' is missing from the ConfigsRecord." + ) + + # Check the value type of the Config.STAGE + next_stage = configs[Key.STAGE] + if not isinstance(next_stage, str): + raise TypeError( + f"The value for the key '{Key.STAGE}' must be of type {str}, " + f"but got {type(next_stage)} instead." 
+ ) + + # Check the validity of the next stage + if next_stage == Stage.SETUP: + if current_stage != Stage.UNMASK: + log(WARNING, "Restart from the setup stage") + # If stage is not "setup", + # the stage from configs should be the expected next stage + else: + stages = Stage.all() + expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)] + if next_stage != expected_next_stage: + raise ValueError( + "Abort secure aggregation: " + f"expect {expected_next_stage} stage, but receive {next_stage} stage" + ) + + +# pylint: disable-next=too-many-branches +def check_configs(stage: str, configs: ConfigsRecord) -> None: + """Check the validity of the configs.""" + # Check configs for the setup stage + if stage == Stage.SETUP: + key_type_pairs = [ + (Key.SAMPLE_NUMBER, int), + (Key.SHARE_NUMBER, int), + (Key.THRESHOLD, int), + (Key.CLIPPING_RANGE, float), + (Key.TARGET_RANGE, int), + (Key.MOD_RANGE, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.SETUP}: the required key '{key}' is " + "missing from the ConfigsRecord." + ) + # Bool is a subclass of int in Python, + # so `isinstance(v, int)` will return True even if v is a boolean. + # pylint: disable-next=unidiomatic-typecheck + if type(configs[key]) is not expected_type: + raise TypeError( + f"Stage {Stage.SETUP}: The value for the key '{key}' " + f"must be of type {expected_type}, " + f"but got {type(configs[key])} instead." + ) + elif stage == Stage.SHARE_KEYS: + for key, value in configs.items(): + if ( + not isinstance(value, list) + or len(value) != 2 + or not isinstance(value[0], bytes) + or not isinstance(value[1], bytes) + ): + raise TypeError( + f"Stage {Stage.SHARE_KEYS}: " + f"the value for the key '{key}' must be a list of two bytes." + ) + elif stage == Stage.COLLECT_MASKED_VECTORS: + key_type_pairs = [ + (Key.CIPHERTEXT_LIST, bytes), + (Key.SOURCE_LIST, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " + f"the required key '{key}' is " + "missing from the ConfigsRecord." + ) + if not isinstance(configs[key], list) or any( + elm + for elm in cast(List[Any], configs[key]) + # pylint: disable-next=unidiomatic-typecheck + if type(elm) is not expected_type + ): + raise TypeError( + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " + f"the value for the key '{key}' " + f"must be of type List[{expected_type.__name__}]" + ) + elif stage == Stage.UNMASK: + key_type_pairs = [ + (Key.ACTIVE_NODE_ID_LIST, int), + (Key.DEAD_NODE_ID_LIST, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.UNMASK}: " + f"the required key '{key}' is " + "missing from the ConfigsRecord." 
+ ) + if not isinstance(configs[key], list) or any( + elm + for elm in cast(List[Any], configs[key]) + # pylint: disable-next=unidiomatic-typecheck + if type(elm) is not expected_type + ): + raise TypeError( + f"Stage {Stage.UNMASK}: " + f"the value for the key '{key}' " + f"must be of type List[{expected_type.__name__}]" + ) + else: + raise ValueError(f"Unknown secagg stage: {stage}") + + +def _setup( + state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: + # Assigning parameter values to object fields + sec_agg_param_dict = configs + state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) + log(INFO, "Node %d: starting stage 0...", state.nid) + + state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER]) + state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD]) + state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE]) + state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE]) + state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE]) + state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT]) + + # Dictionaries containing node IDs as keys + # and their respective secret shares as values. + state.rd_seed_share_dict = {} + state.sk1_share_dict = {} + # Dictionary containing node IDs as keys + # and their respective shared secrets (with this client) as values. + state.ss2_dict = {} + + # Create 2 sets private public key pairs + # One for creating pairwise masks + # One for encrypting message to distribute shares + sk1, pk1 = generate_key_pairs() + sk2, pk2 = generate_key_pairs() + + state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) + state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) + log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid) + return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2} + + +# pylint: disable-next=too-many-locals +def _share_keys( + state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: + named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) + key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} + log(INFO, "Node %d: starting stage 1...", state.nid) + state.public_keys_dict = key_dict + + # Check if the size is larger than threshold + if len(state.public_keys_dict) < state.threshold: + raise ValueError("Available neighbours number smaller than threshold") + + # Check if all public keys are unique + pk_list: List[bytes] = [] + for pk1, pk2 in state.public_keys_dict.values(): + pk_list.append(pk1) + pk_list.append(pk2) + if len(set(pk_list)) != len(pk_list): + raise ValueError("Some public keys are identical") + + # Check if public keys of this client are correct in the dictionary + if ( + state.public_keys_dict[state.nid][0] != state.pk1 + or state.public_keys_dict[state.nid][1] != state.pk2 + ): + raise ValueError( + "Own public keys are displayed in dict incorrectly, should not happen!" 
+ ) + + # Generate the private mask seed + state.rd_seed = os.urandom(32) + + # Create shares for the private mask seed and the first private key + b_shares = create_shares(state.rd_seed, state.threshold, state.share_num) + sk1_shares = create_shares(state.sk1, state.threshold, state.share_num) + + srcs, dsts, ciphertexts = [], [], [] + + # Distribute shares + for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()): + if nid == state.nid: + state.rd_seed_share_dict[state.nid] = b_shares[idx] + state.sk1_share_dict[state.nid] = sk1_shares[idx] + else: + shared_key = generate_shared_key( + bytes_to_private_key(state.sk2), + bytes_to_public_key(pk2), + ) + state.ss2_dict[nid] = shared_key + plaintext = share_keys_plaintext_concat( + state.nid, nid, b_shares[idx], sk1_shares[idx] + ) + ciphertext = encrypt(shared_key, plaintext) + srcs.append(state.nid) + dsts.append(nid) + ciphertexts.append(ciphertext) + + log(INFO, "Node %d: stage 1 completed, uploading key shares...", state.nid) + return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts} + + +# pylint: disable-next=too-many-locals +def _collect_masked_vectors( + state: SecAggPlusState, + configs: ConfigsRecord, + fit: Callable[[], FitRes], +) -> Dict[str, ConfigsRecordValues]: + log(INFO, "Node %d: starting stage 2...", state.nid) + available_clients: List[int] = [] + ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST]) + srcs = cast(List[int], configs[Key.SOURCE_LIST]) + if len(ciphertexts) + 1 < state.threshold: + raise ValueError("Not enough available neighbour clients.") + + # Decrypt ciphertexts, verify their sources, and store shares. + for src, ciphertext in zip(srcs, ciphertexts): + shared_key = state.ss2_dict[src] + plaintext = decrypt(shared_key, ciphertext) + actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate( + plaintext + ) + available_clients.append(src) + if src != actual_src: + raise ValueError( + f"Node {state.nid}: received ciphertext " + f"from {actual_src} instead of {src}." + ) + if dst != state.nid: + raise ValueError( + f"Node {state.nid}: received an encrypted message " + f"for Node {dst} from Node {src}." + ) + state.rd_seed_share_dict[src] = rd_seed_share + state.sk1_share_dict[src] = sk1_share + + # Fit client + fit_res = fit() + if len(fit_res.metrics) > 0: + log( + WARNING, + "The metrics in FitRes will not be preserved or sent to the server.", + ) + ratio = fit_res.num_examples / state.max_weight + if ratio > 1: + log( + WARNING, + "Potential overflow warning: the provided weight (%s) exceeds the specified" + " max_weight (%s). 
This may lead to overflow issues.", + fit_res.num_examples, + state.max_weight, + ) + q_ratio = round(ratio * state.target_range) + dq_ratio = q_ratio / state.target_range + + parameters = parameters_to_ndarrays(fit_res.parameters) + parameters = parameters_multiply(parameters, dq_ratio) + + # Quantize parameter update (vector) + quantized_parameters = quantize( + parameters, state.clipping_range, state.target_range + ) + + quantized_parameters = factor_combine(q_ratio, quantized_parameters) + + dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] + + # Add private mask + private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) + quantized_parameters = parameters_addition(quantized_parameters, private_mask) + + for node_id in available_clients: + # Add pairwise masks + shared_key = generate_shared_key( + bytes_to_private_key(state.sk1), + bytes_to_public_key(state.public_keys_dict[node_id][0]), + ) + pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list) + if state.nid > node_id: + quantized_parameters = parameters_addition( + quantized_parameters, pairwise_mask + ) + else: + quantized_parameters = parameters_subtraction( + quantized_parameters, pairwise_mask + ) + + # Take mod of final weight update vector and return to server + quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) + log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid) + return { + Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] + } + + +def _unmask( + state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: + log(INFO, "Node %d: starting stage 3...", state.nid) + + active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST]) + dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST]) + # Send the private mask seed share for every available client (including itself) + # Send the first private key share for building pairwise masks for every dropped client + if len(active_nids) < state.threshold: + raise ValueError("The number of available neighbours is below the threshold") + + shares: List[bytes] = [] + all_nids = active_nids + dead_nids + shares += [state.rd_seed_share_dict[nid] for nid in active_nids] + shares += [state.sk1_share_dict[nid] for nid in dead_nids] + + log(INFO, "Node %d: stage 3 completed, uploading key shares...", state.nid) + return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares} diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py new file mode 100644 index 000000000000..d72d8b414f65 --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -0,0 +1,316 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
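Why the pairwise masks vanish after aggregation: for each pair of nodes, the node with the larger ID adds the very mask the other node subtracts, so the server's element-wise sum (mod `mod_range`) contains only the quantized updates plus the private masks, which `_unmask` later removes. A self-contained numpy illustration of the cancellation, with made-up sizes:

    import numpy as np

    mod_range = 1 << 24
    rng = np.random.default_rng(seed=42)

    x_a = rng.integers(0, mod_range, size=6)       # node A's quantized update
    x_b = rng.integers(0, mod_range, size=6)       # node B's quantized update
    pairwise = rng.integers(0, mod_range, size=6)  # stands in for pseudo_rand_gen

    masked_a = (x_a + pairwise) % mod_range  # A has the larger node ID: adds
    masked_b = (x_b - pairwise) % mod_range  # B has the smaller node ID: subtracts

    # The server never sees x_a or x_b, yet their sum is recovered exactly
    assert np.array_equal((masked_a + masked_b) % mod_range, (x_a + x_b) % mod_range)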
+# ============================================================================== +"""The SecAgg+ protocol handler tests.""" + +import unittest +from itertools import product +from typing import Callable, Dict, List + +from flwr.client.mod import make_ffn +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet +from flwr.common.constant import MessageType +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, + Key, + Stage, +) +from flwr.common.typing import ConfigsRecordValues + +from .secaggplus_mod import SecAggPlusState, check_configs, secaggplus_mod + + +def get_test_handler( + ctxt: Context, +) -> Callable[[Dict[str, ConfigsRecordValues]], ConfigsRecord]: + """.""" + + def empty_ffn(_msg: Message, _2: Context) -> Message: + return _msg.create_reply(RecordSet(), ttl="") + + app = make_ffn(empty_ffn, [secaggplus_mod]) + + def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: + in_msg = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=123, + reply_to_message="", + group_id="", + ttl="", + message_type=MessageType.TRAIN, + ), + content=RecordSet( + configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)} + ), + ) + out_msg = app(in_msg, ctxt) + return out_msg.content.configs_records[RECORD_KEY_CONFIGS] + + return func + + +def _make_ctxt() -> Context: + cfg = ConfigsRecord(SecAggPlusState().to_dict()) + return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg})) + + +def _make_set_state_fn( + ctxt: Context, +) -> Callable[[str], None]: + def set_stage(stage: str) -> None: + state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] + state = SecAggPlusState(**state_dict) + state.current_stage = stage + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) + + return set_stage + + +class TestSecAggPlusHandler(unittest.TestCase): + """Test the SecAgg+ protocol handler.""" + + def test_stage_transition(self) -> None: + """Test stage transition.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + assert Stage.all() == ( + Stage.SETUP, + Stage.SHARE_KEYS, + Stage.COLLECT_MASKED_VECTORS, + Stage.UNMASK, + ) + + valid_transitions = { + # From one stage to the next stage + (Stage.UNMASK, Stage.SETUP), + (Stage.SETUP, Stage.SHARE_KEYS), + (Stage.SHARE_KEYS, Stage.COLLECT_MASKED_VECTORS), + (Stage.COLLECT_MASKED_VECTORS, Stage.UNMASK), + # From any stage to the initial stage + # Such transitions will log a warning. + (Stage.SETUP, Stage.SETUP), + (Stage.SHARE_KEYS, Stage.SETUP), + (Stage.COLLECT_MASKED_VECTORS, Stage.SETUP), + } + + invalid_transitions = set(product(Stage.all(), Stage.all())).difference( + valid_transitions + ) + + # Test valid transitions + # If the next stage is valid, the function should update the current stage + # and then raise KeyError or other exceptions when trying to execute SA. 
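`_make_ctxt` and `set_stage` above rely on `SecAggPlusState` surviving a round trip through a `ConfigsRecord`; a minimal sketch of that round trip using the same calls:

    from flwr.common import ConfigsRecord
    from flwr.common.secure_aggregation.secaggplus_constants import Stage
    from flwr.client.mod.secure_aggregation.secaggplus_mod import SecAggPlusState

    state = SecAggPlusState()
    state.current_stage = Stage.SHARE_KEYS

    record = ConfigsRecord(state.to_dict())  # serialize into context.state
    restored = SecAggPlusState(**record)     # rebuild on the way back in
    assert restored.current_stage == Stage.SHARE_KEYS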
+ for current_stage, next_stage in valid_transitions: + set_stage(current_stage) + + with self.assertRaises(KeyError): + handler({Key.STAGE: next_stage}) + + # Test invalid transitions + # If the next stage is invalid, the function should raise ValueError + for current_stage, next_stage in invalid_transitions: + set_stage(current_stage) + + with self.assertRaises(ValueError): + handler({Key.STAGE: next_stage}) + + def test_stage_setup_check(self) -> None: + """Test content checking for the setup stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_key_type_pairs = [ + (Key.SAMPLE_NUMBER, int), + (Key.SHARE_NUMBER, int), + (Key.THRESHOLD, int), + (Key.CLIPPING_RANGE, float), + (Key.TARGET_RANGE, int), + (Key.MOD_RANGE, int), + ] + + type_to_test_value: Dict[type, ConfigsRecordValues] = { + int: 10, + bool: True, + float: 1.0, + str: "test", + bytes: b"test", + } + + valid_configs: Dict[str, ConfigsRecordValues] = { + key: type_to_test_value[value_type] + for key, value_type in valid_key_type_pairs + } + + # Test valid configs + try: + check_configs(Stage.SETUP, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.SETUP + + # Test invalid configs + for key, value_type in valid_key_type_pairs: + invalid_configs = valid_configs.copy() + + # Test wrong value type for the key + for other_type, other_value in type_to_test_value.items(): + if other_type == value_type: + continue + invalid_configs[key] = other_value + + set_stage(Stage.UNMASK) + with self.assertRaises(TypeError): + handler(invalid_configs.copy()) + + # Test missing key + invalid_configs.pop(key) + + set_stage(Stage.UNMASK) + with self.assertRaises(KeyError): + handler(invalid_configs.copy()) + + def test_stage_share_keys_check(self) -> None: + """Test content checking for the share keys stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + "1": [b"public key 1", b"public key 2"], + "2": [b"public key 1", b"public key 2"], + "3": [b"public key 1", b"public key 2"], + } + + # Test valid configs + try: + check_configs(Stage.SHARE_KEYS, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.SHARE_KEYS + + # Test invalid configs + invalid_values: List[ConfigsRecordValues] = [ + b"public key 1", + [b"public key 1"], + [b"public key 1", b"public key 2", b"public key 3"], + ] + + for value in invalid_values: + invalid_configs = valid_configs.copy() + invalid_configs["1"] = value + + set_stage(Stage.SETUP) + with self.assertRaises(TypeError): + handler(invalid_configs.copy()) + + def test_stage_collect_masked_vectors_check(self) -> None: + """Test content checking for the collect masked vectors stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + Key.CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], + Key.SOURCE_LIST: [32, 51324, 32324123, -3], + } + + # Test valid configs + try: + check_configs(Stage.COLLECT_MASKED_VECTORS, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + 
self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.COLLECT_MASKED_VECTORS + + # Test invalid configs + # Test missing keys + for key in list(valid_configs.keys()): + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(Stage.SHARE_KEYS) + with self.assertRaises(KeyError): + handler(invalid_configs) + + # Test wrong value type for the key + for key in valid_configs: + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs[key] = [3.1415926] + + set_stage(Stage.SHARE_KEYS) + with self.assertRaises(TypeError): + handler(invalid_configs) + + def test_stage_unmask_check(self) -> None: + """Test content checking for the unmasking stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + Key.ACTIVE_NODE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + Key.DEAD_NODE_ID_LIST: [32, 51324, 32324123, -3], + } + + # Test valid configs + try: + check_configs(Stage.UNMASK, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.UNMASK + + # Test invalid configs + # Test missing keys + for key in list(valid_configs.keys()): + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(Stage.COLLECT_MASKED_VECTORS) + with self.assertRaises(KeyError): + handler(invalid_configs) + + # Test wrong value type for the key + for key in valid_configs: + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs[key] = [True, False, True, False] + + set_stage(Stage.COLLECT_MASKED_VECTORS) + with self.assertRaises(TypeError): + handler(invalid_configs) diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/mod/utils.py similarity index 62% rename from src/py/flwr/client/middleware/utils.py rename to src/py/flwr/client/mod/utils.py index d93132403c1e..4c3c32944f01 100644 --- a/src/py/flwr/client/middleware/utils.py +++ b/src/py/flwr/client/mod/utils.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Utility functions for middleware layers.""" +"""Utility functions for mods.""" from typing import List -from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common import Context, Message -def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: +def make_ffn(ffn: ClientAppCallable, mods: List[Mod]) -> ClientAppCallable: """.""" - def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: - def new_ffn(fwd: Fwd) -> Bwd: - return _layer(fwd, _ffn) + def wrap_ffn(_ffn: ClientAppCallable, _mod: Mod) -> ClientAppCallable: + def new_ffn(message: Message, context: Context) -> Message: + return _mod(message, context, _ffn) return new_ffn - for layer in reversed(layers): - ffn = wrap_ffn(ffn, layer) + for mod in reversed(mods): + ffn = wrap_ffn(ffn, mod) return ffn diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py new file mode 100644 index 000000000000..e588b8b53b3b --- /dev/null +++ b/src/py/flwr/client/mod/utils_test.py @@ -0,0 +1,154 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
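Before the tests below, it may help to see the wrapping order `make_ffn` produces: because the mod list is applied in reverse, the first mod in the list becomes the outermost wrapper. A compact sketch (mod and app names are illustrative):

    from flwr.client.mod import make_ffn
    from flwr.common import Context, Message, Metadata, RecordSet

    trace = []

    def mod_a(msg, ctx, call_next):
        trace.append("a:in")
        out = call_next(msg, ctx)
        trace.append("a:out")
        return out

    def mod_b(msg, ctx, call_next):
        trace.append("b:in")
        out = call_next(msg, ctx)
        trace.append("b:out")
        return out

    def app(msg, ctx):
        trace.append("app")
        return msg

    wrapped = make_ffn(app, [mod_a, mod_b])
    msg = Message(
        content=RecordSet(),
        metadata=Metadata(
            run_id=0, message_id="", group_id="", src_node_id=0,
            dst_node_id=0, reply_to_message="", ttl="", message_type="mock",
        ),
    )
    wrapped(msg, Context(state=RecordSet()))
    assert trace == ["a:in", "b:in", "app", "b:out", "a:out"]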
+# ============================================================================== +"""Tests for the utility functions.""" + + +import unittest +from typing import List, cast + +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common import ( + ConfigsRecord, + Context, + Message, + Metadata, + MetricsRecord, + RecordSet, +) + +from .utils import make_ffn + +METRIC = "context" +COUNTER = "counter" + + +def _increment_context_counter(context: Context) -> None: + # Read from context + current_counter = cast(int, context.state.metrics_records[METRIC][COUNTER]) + # Update and override the context + current_counter += 1 + context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter}) + + +def make_mock_mod(name: str, footprint: List[str]) -> Mod: + """Make a mock mod.""" + + def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: + footprint.append(name) + # Add an empty ConfigsRecord to in_message for this mod + message.content.configs_records[name] = ConfigsRecord() + _increment_context_counter(context) + out_message: Message = app(message, context) + footprint.append(name) + _increment_context_counter(context) + # Add an empty ConfigsRecord to out_message for this mod + out_message.content.configs_records[name] = ConfigsRecord() + return out_message + + return mod + + +def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable: + """Make a mock app.""" + + def app(message: Message, context: Context) -> Message: + footprint.append(name) + message.content.configs_records[name] = ConfigsRecord() + out_message = Message(metadata=message.metadata, content=RecordSet()) + out_message.content.configs_records[name] = ConfigsRecord() + print(context) + return out_message + + return app + + +def _get_dummy_flower_message() -> Message: + return Message( + content=RecordSet(), + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + ttl="", + message_type="mock", + ), + ) + + +class TestMakeApp(unittest.TestCase): + """Tests for the `make_ffn` function.""" + + def test_multiple_mods(self) -> None: + """Test if multiple mods are called in the correct order.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + mock_mod_names = [f"mod{i}" for i in range(1, 15)] + mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names] + + state = RecordSet() + state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) + context = Context(state=state) + message = _get_dummy_flower_message() + + # Execute + wrapped_app = make_ffn(mock_app, mock_mods) + out_message = wrapped_app(message, context) + + # Assert + trace = mock_mod_names + ["app"] + self.assertEqual(footprint, trace + list(reversed(mock_mod_names))) + # pylint: disable-next=no-member + self.assertEqual( + "".join(message.content.configs_records.keys()), "".join(trace) + ) + self.assertEqual( + "".join(out_message.content.configs_records.keys()), + "".join(reversed(trace)), + ) + self.assertEqual(state.metrics_records[METRIC][COUNTER], 2 * len(mock_mods)) + + def test_filter(self) -> None: + """Test if a mod can filter an incoming message.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + context = Context(state=RecordSet()) + message = _get_dummy_flower_message() + + def filter_mod( + message: Message, + _1: Context, + _2: ClientAppCallable, + ) -> Message: + footprint.append("filter") + message.content.configs_records["filter"] = ConfigsRecord() + 
out_message = Message(metadata=message.metadata, content=RecordSet()) + out_message.content.configs_records["filter"] = ConfigsRecord() + # Skip calling app + return out_message + + # Execute + wrapped_app = make_ffn(mock_app, [filter_mod]) + out_message = wrapped_app(message, context) + + # Assert + self.assertEqual(footprint, ["filter"]) + # pylint: disable-next=no-member + self.assertEqual(list(message.content.configs_records.keys())[0], "filter") + self.assertEqual(list(out_message.content.configs_records.keys())[0], "filter") diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 0a29be511806..71681b783419 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from typing import Any, Dict -from flwr.client.run_state import RunState +from flwr.common import Context, RecordSet class NodeState: @@ -25,24 +25,24 @@ class NodeState: def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_states: Dict[int, RunState] = {} + self.run_contexts: Dict[int, Context] = {} - def register_runstate(self, run_id: int) -> None: - """Register new run state for this node.""" - if run_id not in self.run_states: - self.run_states[run_id] = RunState({}) + def register_context(self, run_id: int) -> None: + """Register new run context for this node.""" + if run_id not in self.run_contexts: + self.run_contexts[run_id] = Context(state=RecordSet()) - def retrieve_runstate(self, run_id: int) -> RunState: - """Get run state given a run_id.""" - if run_id in self.run_states: - return self.run_states[run_id] + def retrieve_context(self, run_id: int) -> Context: + """Get run context given a run_id.""" + if run_id in self.run_contexts: + return self.run_contexts[run_id] raise RuntimeError( - f"RunState for run_id={run_id} doesn't exist." - " A run must be registered before it can be retrieved or updated " + f"Context for run_id={run_id} doesn't exist." + " A run context must be registered before it can be retrieved or updated " " by a client." 
) - def update_runstate(self, run_id: int, run_state: RunState) -> None: - """Update run state.""" - self.run_states[run_id] = run_state + def update_context(self, run_id: int, context: Context) -> None: + """Update run context.""" + self.run_contexts[run_id] = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 7bc0d77d16cf..193f52661579 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -15,18 +15,22 @@ """Node state tests.""" +from typing import cast + from flwr.client.node_state import NodeState -from flwr.client.run_state import RunState +from flwr.common import ConfigsRecord, Context from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def _run_dummy_task(state: RunState) -> RunState: - if "counter" in state.state: - state.state["counter"] += "1" - else: - state.state["counter"] = "1" +def _run_dummy_task(context: Context) -> Context: + counter_value: str = "1" + if "counter" in context.state.configs_records.keys(): + counter_value = cast(str, context.state.configs_records["counter"]["count"]) + counter_value += "1" + + context.state.configs_records["counter"] = ConfigsRecord({"count": counter_value}) - return state + return context def test_multirun_in_node_state() -> None: @@ -43,17 +47,19 @@ def test_multirun_in_node_state() -> None: run_id = task.run_id # Register - node_state.register_runstate(run_id=run_id) + node_state.register_context(run_id=run_id) # Get run state - state = node_state.retrieve_runstate(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Run "task" - updated_state = _run_dummy_task(state) + updated_state = _run_dummy_task(context) # Update run state - node_state.update_runstate(run_id=run_id, run_state=updated_state) + node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, state in node_state.run_states.items(): - assert state.state["counter"] == expected_values[run_id] + for run_id, context in node_state.run_contexts.items(): + assert ( + context.state.configs_records["counter"]["count"] == expected_values[run_id] + ) diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index d67fb90512d4..0247958d88a9 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,9 +19,9 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client -from flwr.client.run_state import RunState from flwr.common import ( Config, + Context, NDArrays, Scalar, ndarrays_to_parameters, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: RunState + context: Context def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. 
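The renamed `NodeState` API keeps one `Context` per run; a short usage sketch of the register/retrieve/update cycle exercised by the test above (the run ID is arbitrary):

    from flwr.client.node_state import NodeState
    from flwr.common import ConfigsRecord

    node_state = NodeState()
    node_state.register_context(run_id=7)

    context = node_state.retrieve_context(run_id=7)
    context.state.configs_records["counter"] = ConfigsRecord({"count": "1"})
    node_state.update_context(run_id=7, context=context)

    # Retrieving a run that was never registered raises RuntimeError
    node_state.retrieve_context(run_id=99)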
@@ -174,13 +174,13 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> RunState: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RunState) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Convert to object to Client type and return it.""" @@ -278,21 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> RunState: - """Return state of underlying NumPyClient.""" - return self.numpy_client.get_state() # type: ignore +def _get_context(self: Client) -> Context: + """Return context of underlying NumPyClient.""" + return self.numpy_client.get_context() # type: ignore -def _set_state(self: Client, state: RunState) -> None: - """Apply state to underlying NumPyClient.""" - self.numpy_client.set_state(state) # type: ignore +def _set_context(self: Client, context: Context) -> None: + """Apply context to underlying NumPyClient.""" + self.numpy_client.set_context(context) # type: ignore def _wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, - "get_state": _get_state, - "set_state": _set_state, + "get_context": _get_context, + "set_context": _set_context, } # Add wrapper type methods (if overridden) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index bb55f130f1a8..d2cc71ba3b3f 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -17,18 +17,18 @@ import sys from contextlib import contextmanager +from copy import copy from logging import ERROR, INFO, WARN from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast -from flwr.client.message_handler.task_handler import ( - configure_task_res, - get_task_ins, - validate_task_ins, - validate_task_res, -) +from flwr.client.message_handler.message_handler import validate_out_message +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.constant import MISSING_EXTRA_REST from flwr.common.logger import log +from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker +from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -39,7 +39,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 try: import requests @@ -48,7 +48,7 @@ KEY_NODE = "node" -KEY_TASK_INS = "current_task_ins" +KEY_METADATA = "in_message_metadata" PATH_CREATE_NODE: str = "api/v0/fleet/create-node" @@ -62,14 +62,15 @@ def http_request_response( server_address: str, insecure: bool, # pylint: disable=unused-argument + retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[ Union[bytes, str] ] = None, # pylint: disable=unused-argument ) -> Iterator[ Tuple[ - Callable[[], Optional[TaskIns]], 
- Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] @@ -85,6 +86,12 @@ def http_request_response( The IPv6 address of the server with `http://` or `https://`. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Unused argument, present for compatibility. + retry_invoker : RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after REST connection errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. root_certificates : Optional[Union[bytes, str]] (default: None) @@ -120,8 +127,8 @@ def http_request_response( "must be provided as a string path to the client.", ) - # Necessary state to link TaskRes to TaskIns - state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} + # Necessary state to validate messages to be sent + state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} # Enable create_node and delete_node to store node node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} @@ -135,7 +142,8 @@ def create_node() -> None: create_node_req_proto = CreateNodeRequest() create_node_req_bytes: bytes = create_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_CREATE_NODE}", headers={ "Accept": "application/protobuf", @@ -178,7 +186,8 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_req_proto = DeleteNodeRequest(node=node) delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_DELETE_NODE}", headers={ "Accept": "application/protobuf", @@ -206,7 +215,7 @@ def delete_node() -> None: PATH_PULL_TASK_INS, ) - def receive() -> Optional[TaskIns]: + def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node if node_store[KEY_NODE] is None: @@ -219,7 +228,8 @@ def receive() -> Optional[TaskIns]: pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString() # Request instructions (task) from server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PULL_TASK_INS}", headers={ "Accept": "application/protobuf", @@ -256,40 +266,41 @@ def receive() -> Optional[TaskIns]: task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True + if task_ins is not None and not ( + task_ins.task.consumer.node_id == node.node_id + and validate_task_ins(task_ins) ): task_ins = None - # Remember `task_ins` until `task_res` is available - state[KEY_TASK_INS] = task_ins - - # Return the TaskIns if available + # Return the Message if available + message = None + state[KEY_METADATA] = None + if task_ins is not None: + message = message_from_taskins(task_ins) + state[KEY_METADATA] = copy(message.metadata) log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS) - return task_ins + return message - def send(task_res: TaskRes) -> None: + def send(message: Message) -> None: """Send task result back to server.""" # Get Node if node_store[KEY_NODE] is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) - if 
state[KEY_TASK_INS] is None: - log(ERROR, "No current TaskIns") + # Get incoming message + in_metadata = state[KEY_METADATA] + if in_metadata is None: + log(ERROR, "No current message") return - task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS]) - - # Check if fields to be set are not initialized - if not validate_task_res(task_res): - state[KEY_TASK_INS] = None - log(ERROR, "TaskRes has been initialized accidentally") + # Validate out message + if not validate_out_message(message, in_metadata): + log(ERROR, "Invalid out message") + return - # Configure TaskRes - task_res = configure_task_res(task_res, task_ins, node) + # Construct TaskRes + task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes push_task_res_request_proto = PushTaskResRequest(task_res_list=[task_res]) @@ -298,7 +309,8 @@ def send(task_res: TaskRes) -> None: ) # Send ClientMessage to server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PUSH_TASK_RES}", headers={ "Accept": "application/protobuf", @@ -309,7 +321,7 @@ def send(task_res: TaskRes) -> None: timeout=None, ) - state[KEY_TASK_INS] = None + state[KEY_METADATA] = None # Check status code and headers if res.status_code != 200: diff --git a/src/py/flwr/client/secure_aggregation/handler.py b/src/py/flwr/client/secure_aggregation/handler.py deleted file mode 100644 index 487ed842c93f..000000000000 --- a/src/py/flwr/client/secure_aggregation/handler.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Message Handler for Secure Aggregation (abstract base class).""" - - -from abc import ABC, abstractmethod -from typing import Dict - -from flwr.common.typing import Value - - -class SecureAggregationHandler(ABC): - """Abstract base class for secure aggregation message handlers.""" - - @abstractmethod - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming Secure Aggregation message and return results. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the Secure Aggregation protocol. - """ diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py deleted file mode 100644 index 4b74c1ace3de..000000000000 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ /dev/null @@ -1,488 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
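The REST transport now routes every `requests.post` through `RetryInvoker.invoke`. A hypothetical sketch of constructing such an invoker; the `exponential` wait generator and the constructor argument order (wait generator, recoverable exceptions, `max_tries`, `max_time`) are assumptions about `flwr.common.retry_invoker`, and the endpoint is illustrative:

    import requests
    from flwr.common.retry_invoker import RetryInvoker, exponential  # `exponential` assumed

    # Retry transient connection errors a few times with exponential backoff
    invoker = RetryInvoker(
        exponential,
        requests.exceptions.ConnectionError,
        max_tries=3,
        max_time=None,
    )

    res = invoker.invoke(
        requests.post,
        url="http://[::]:8080/api/v0/fleet/pull-task-ins",  # illustrative endpoint
        headers={"Accept": "application/protobuf", "Content-Type": "application/protobuf"},
        data=b"",
        timeout=None,
    )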
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Message handler for the SecAgg+ protocol.""" - - -import os -from dataclasses import dataclass, field -from logging import ERROR, INFO, WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast - -from flwr.client.client import Client -from flwr.client.numpy_client import NumPyClient -from flwr.common import ( - bytes_to_ndarray, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) -from flwr.common.logger import log -from flwr.common.secure_aggregation.crypto.shamir import create_shares -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, - bytes_to_public_key, - decrypt, - encrypt, - generate_key_pairs, - generate_shared_key, - private_key_to_bytes, - public_key_to_bytes, -) -from flwr.common.secure_aggregation.ndarrays_arithmetic import ( - factor_combine, - parameters_addition, - parameters_mod, - parameters_multiply, - parameters_subtraction, -) -from flwr.common.secure_aggregation.quantization import quantize -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_DESTINATION_LIST, - KEY_MASKED_PARAMETERS, - KEY_MOD_RANGE, - KEY_PARAMETERS, - KEY_PUBLIC_KEY_1, - KEY_PUBLIC_KEY_2, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SECURE_ID_LIST, - KEY_SHARE_LIST, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, - STAGES, -) -from flwr.common.secure_aggregation.secaggplus_utils import ( - pseudo_rand_gen, - share_keys_plaintext_concat, - share_keys_plaintext_separate, -) -from flwr.common.typing import FitIns, Value - -from .handler import SecureAggregationHandler - - -@dataclass -# pylint: disable-next=too-many-instance-attributes -class SecAggPlusState: - """State of the SecAgg+ protocol.""" - - sid: int = 0 - sample_num: int = 0 - share_num: int = 0 - threshold: int = 0 - clipping_range: float = 0.0 - target_range: int = 0 - mod_range: int = 0 - - # Secret key (sk) and public key (pk) - sk1: bytes = b"" - pk1: bytes = b"" - sk2: bytes = b"" - pk2: bytes = b"" - - # Random seed for generating the private mask - rd_seed: bytes = b"" - - rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) - sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) - # The dict of the shared secrets from sk2 - ss2_dict: Dict[int, bytes] = field(default_factory=dict) - public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) - - client: Optional[Union[Client, NumPyClient]] = None - - -class SecAggPlusHandler(SecureAggregationHandler): - """Message handler for the SecAgg+ protocol.""" - - _shared_state = SecAggPlusState() - _current_stage = STAGE_UNMASK - - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming message and return results, following the SecAgg+ protocol. 
- - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the SecAgg+ protocol. - """ - # Check if self is a client - if not isinstance(self, (Client, NumPyClient)): - raise TypeError( - "The subclass of SecAggPlusHandler must be " - "the subclass of Client or NumPyClient." - ) - - # Check the validity of the next stage - check_stage(self._current_stage, named_values) - - # Update the current stage - self._current_stage = cast(str, named_values.pop(KEY_STAGE)) - - # Check the validity of the `named_values` based on the current stage - check_named_values(self._current_stage, named_values) - - # Execute - if self._current_stage == STAGE_SETUP: - self._shared_state = SecAggPlusState(client=self) - return _setup(self._shared_state, named_values) - if self._current_stage == STAGE_SHARE_KEYS: - return _share_keys(self._shared_state, named_values) - if self._current_stage == STAGE_COLLECT_MASKED_INPUT: - return _collect_masked_input(self._shared_state, named_values) - if self._current_stage == STAGE_UNMASK: - return _unmask(self._shared_state, named_values) - raise ValueError(f"Unknown secagg stage: {self._current_stage}") - - -def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the next stage.""" - # Check the existence of KEY_STAGE - if KEY_STAGE not in named_values: - raise KeyError( - f"The required key '{KEY_STAGE}' is missing from the input `named_values`." - ) - - # Check the value type of the KEY_STAGE - next_stage = named_values[KEY_STAGE] - if not isinstance(next_stage, str): - raise TypeError( - f"The value for the key '{KEY_STAGE}' must be of type {str}, " - f"but got {type(next_stage)} instead." - ) - - # Check the validity of the next stage - if next_stage == STAGE_SETUP: - if current_stage != STAGE_UNMASK: - log(WARNING, "Restart from the setup stage") - # If stage is not "setup", - # the stage from `named_values` should be the expected next stage - else: - expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)] - if next_stage != expected_next_stage: - raise ValueError( - "Abort secure aggregation: " - f"expect {expected_next_stage} stage, but receive {next_stage} stage" - ) - - -# pylint: disable-next=too-many-branches -def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the input `named_values`.""" - # Check `named_values` for the setup stage - if stage == STAGE_SETUP: - key_type_pairs = [ - (KEY_SAMPLE_NUMBER, int), - (KEY_SECURE_ID, int), - (KEY_SHARE_NUMBER, int), - (KEY_THRESHOLD, int), - (KEY_CLIPPING_RANGE, float), - (KEY_TARGET_RANGE, int), - (KEY_MOD_RANGE, int), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_SETUP}: the required key '{key}' is " - "missing from the input `named_values`." - ) - # Bool is a subclass of int in Python, - # so `isinstance(v, int)` will return True even if v is a boolean. - # pylint: disable-next=unidiomatic-typecheck - if type(named_values[key]) is not expected_type: - raise TypeError( - f"Stage {STAGE_SETUP}: The value for the key '{key}' " - f"must be of type {expected_type}, " - f"but got {type(named_values[key])} instead." 
- ) - elif stage == STAGE_SHARE_KEYS: - for key, value in named_values.items(): - if ( - not isinstance(value, list) - or len(value) != 2 - or not isinstance(value[0], bytes) - or not isinstance(value[1], bytes) - ): - raise TypeError( - f"Stage {STAGE_SHARE_KEYS}: " - f"the value for the key '{key}' must be a list of two bytes." - ) - elif stage == STAGE_COLLECT_MASKED_INPUT: - key_type_pairs = [ - (KEY_CIPHERTEXT_LIST, bytes), - (KEY_SOURCE_LIST, int), - (KEY_PARAMETERS, bytes), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_COLLECT_MASKED_INPUT}: " - f"the required key '{key}' is " - "missing from the input `named_values`." - ) - if not isinstance(named_values[key], list) or any( - elm - for elm in cast(List[Any], named_values[key]) - # pylint: disable-next=unidiomatic-typecheck - if type(elm) is not expected_type - ): - raise TypeError( - f"Stage {STAGE_COLLECT_MASKED_INPUT}: " - f"the value for the key '{key}' " - f"must be of type List[{expected_type.__name__}]" - ) - elif stage == STAGE_UNMASK: - key_type_pairs = [ - (KEY_ACTIVE_SECURE_ID_LIST, int), - (KEY_DEAD_SECURE_ID_LIST, int), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_UNMASK}: " - f"the required key '{key}' is " - "missing from the input `named_values`." - ) - if not isinstance(named_values[key], list) or any( - elm - for elm in cast(List[Any], named_values[key]) - # pylint: disable-next=unidiomatic-typecheck - if type(elm) is not expected_type - ): - raise TypeError( - f"Stage {STAGE_UNMASK}: " - f"the value for the key '{key}' " - f"must be of type List[{expected_type.__name__}]" - ) - else: - raise ValueError(f"Unknown secagg stage: {stage}") - - -def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: - # Assigning parameter values to object fields - sec_agg_param_dict = named_values - state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER]) - state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID]) - log(INFO, "Client %d: starting stage 0...", state.sid) - - state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER]) - state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD]) - state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE]) - state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE]) - state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE]) - - # Dictionaries containing client secure IDs as keys - # and their respective secret shares as values. - state.rd_seed_share_dict = {} - state.sk1_share_dict = {} - # Dictionary containing client secure IDs as keys - # and their respective shared secrets (with this client) as values. - state.ss2_dict = {} - - # Create 2 sets private public key pairs - # One for creating pairwise masks - # One for encrypting message to distribute shares - sk1, pk1 = generate_key_pairs() - sk2, pk2 = generate_key_pairs() - - state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) - state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) - log(INFO, "Client %d: stage 0 completes. 
uploading public keys...", state.sid) - return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2} - - -# pylint: disable-next=too-many-locals -def _share_keys( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values) - key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} - log(INFO, "Client %d: starting stage 1...", state.sid) - state.public_keys_dict = key_dict - - # Check if the size is larger than threshold - if len(state.public_keys_dict) < state.threshold: - raise ValueError("Available neighbours number smaller than threshold") - - # Check if all public keys are unique - pk_list: List[bytes] = [] - for pk1, pk2 in state.public_keys_dict.values(): - pk_list.append(pk1) - pk_list.append(pk2) - if len(set(pk_list)) != len(pk_list): - raise ValueError("Some public keys are identical") - - # Check if public keys of this client are correct in the dictionary - if ( - state.public_keys_dict[state.sid][0] != state.pk1 - or state.public_keys_dict[state.sid][1] != state.pk2 - ): - raise ValueError( - "Own public keys are displayed in dict incorrectly, should not happen!" - ) - - # Generate the private mask seed - state.rd_seed = os.urandom(32) - - # Create shares for the private mask seed and the first private key - b_shares = create_shares(state.rd_seed, state.threshold, state.share_num) - sk1_shares = create_shares(state.sk1, state.threshold, state.share_num) - - srcs, dsts, ciphertexts = [], [], [] - - # Distribute shares - for idx, (sid, (_, pk2)) in enumerate(state.public_keys_dict.items()): - if sid == state.sid: - state.rd_seed_share_dict[state.sid] = b_shares[idx] - state.sk1_share_dict[state.sid] = sk1_shares[idx] - else: - shared_key = generate_shared_key( - bytes_to_private_key(state.sk2), - bytes_to_public_key(pk2), - ) - state.ss2_dict[sid] = shared_key - plaintext = share_keys_plaintext_concat( - state.sid, sid, b_shares[idx], sk1_shares[idx] - ) - ciphertext = encrypt(shared_key, plaintext) - srcs.append(state.sid) - dsts.append(sid) - ciphertexts.append(ciphertext) - - log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid) - return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts} - - -# pylint: disable-next=too-many-locals -def _collect_masked_input( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: - log(INFO, "Client %d: starting stage 2...", state.sid) - available_clients: List[int] = [] - ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) - srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) - if len(ciphertexts) + 1 < state.threshold: - raise ValueError("Not enough available neighbour clients.") - - # Decrypt ciphertexts, verify their sources, and store shares. - for src, ciphertext in zip(srcs, ciphertexts): - shared_key = state.ss2_dict[src] - plaintext = decrypt(shared_key, ciphertext) - actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate( - plaintext - ) - available_clients.append(src) - if src != actual_src: - raise ValueError( - f"Client {state.sid}: received ciphertext " - f"from {actual_src} instead of {src}." - ) - if dst != state.sid: - raise ValueError( - f"Client {state.sid}: received an encrypted message" - f"for Client {dst} from Client {src}." 
- ) - state.rd_seed_share_dict[src] = rd_seed_share - state.sk1_share_dict[src] = sk1_share - - # Fit client - parameters_bytes = cast(List[bytes], named_values[KEY_PARAMETERS]) - parameters = [bytes_to_ndarray(w) for w in parameters_bytes] - if isinstance(state.client, Client): - fit_res = state.client.fit( - FitIns(parameters=ndarrays_to_parameters(parameters), config={}) - ) - parameters_factor = fit_res.num_examples - parameters = parameters_to_ndarrays(fit_res.parameters) - elif isinstance(state.client, NumPyClient): - parameters, parameters_factor, _ = state.client.fit(parameters, {}) - else: - log(ERROR, "Client %d: fit function is missing.", state.sid) - - # Quantize parameter update (vector) - quantized_parameters = quantize( - parameters, state.clipping_range, state.target_range - ) - - quantized_parameters = parameters_multiply(quantized_parameters, parameters_factor) - quantized_parameters = factor_combine(parameters_factor, quantized_parameters) - - dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] - - # Add private mask - private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) - quantized_parameters = parameters_addition(quantized_parameters, private_mask) - - for client_id in available_clients: - # Add pairwise masks - shared_key = generate_shared_key( - bytes_to_private_key(state.sk1), - bytes_to_public_key(state.public_keys_dict[client_id][0]), - ) - pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list) - if state.sid > client_id: - quantized_parameters = parameters_addition( - quantized_parameters, pairwise_mask - ) - else: - quantized_parameters = parameters_subtraction( - quantized_parameters, pairwise_mask - ) - - # Take mod of final weight update vector and return to server - quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) - log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid) - return { - KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] - } - - -def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: - log(INFO, "Client %d: starting stage 3...", state.sid) - - active_sids = cast(List[int], named_values[KEY_ACTIVE_SECURE_ID_LIST]) - dead_sids = cast(List[int], named_values[KEY_DEAD_SECURE_ID_LIST]) - # Send private mask seed share for every avaliable client (including itclient) - # Send first private key share for building pairwise mask for every dropped client - if len(active_sids) < state.threshold: - raise ValueError("Available neighbours number smaller than threshold") - - sids, shares = [], [] - sids += active_sids - shares += [state.rd_seed_share_dict[sid] for sid in active_sids] - sids += dead_sids - shares += [state.sk1_share_dict[sid] for sid in dead_sids] - - log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid) - return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares} diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py deleted file mode 100644 index 9693a46af989..000000000000 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The SecAgg+ protocol handler tests.""" - -import unittest -from itertools import product -from typing import Any, Dict, List, cast - -from flwr.client import NumPyClient -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_MOD_RANGE, - KEY_PARAMETERS, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, - STAGES, -) -from flwr.common.typing import Value - -from .secaggplus_handler import SecAggPlusHandler, check_named_values - - -class EmptyFlowerNumPyClient(NumPyClient, SecAggPlusHandler): - """Empty NumPyClient.""" - - -class TestSecAggPlusHandler(unittest.TestCase): - """Test the SecAgg+ protocol handler.""" - - def test_invalid_handler(self) -> None: - """Test invalid handler.""" - handler = SecAggPlusHandler() - - with self.assertRaises(TypeError): - handler.handle_secure_aggregation({}) - - def test_stage_transition(self) -> None: - """Test stage transition.""" - handler = EmptyFlowerNumPyClient() - - assert STAGES == ( - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_COLLECT_MASKED_INPUT, - STAGE_UNMASK, - ) - - valid_transitions = { - # From one stage to the next stage - (STAGE_UNMASK, STAGE_SETUP), - (STAGE_SETUP, STAGE_SHARE_KEYS), - (STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT), - (STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK), - # From any stage to the initial stage - # Such transitions will log a warning. - (STAGE_SETUP, STAGE_SETUP), - (STAGE_SHARE_KEYS, STAGE_SETUP), - (STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP), - } - - invalid_transitions = set(product(STAGES, STAGES)).difference(valid_transitions) - - # Test valid transitions - # If the next stage is valid, the function should update the current stage - # and then raise KeyError or other exceptions when trying to execute SA. 
- for current_stage, next_stage in valid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage - - with self.assertRaises(KeyError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == next_stage - - # Test invalid transitions - # If the next stage is invalid, the function should raise ValueError - for current_stage, next_stage in invalid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage - - with self.assertRaises(ValueError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == current_stage - - def test_stage_setup_check(self) -> None: - """Test content checking for the setup stage.""" - handler = EmptyFlowerNumPyClient() - - valid_key_type_pairs = [ - (KEY_SAMPLE_NUMBER, int), - (KEY_SECURE_ID, int), - (KEY_SHARE_NUMBER, int), - (KEY_THRESHOLD, int), - (KEY_CLIPPING_RANGE, float), - (KEY_TARGET_RANGE, int), - (KEY_MOD_RANGE, int), - ] - - type_to_test_value: Dict[type, Value] = { - int: 10, - bool: True, - float: 1.0, - str: "test", - bytes: b"test", - } - - valid_named_values: Dict[str, Value] = { - key: type_to_test_value[value_type] - for key, value_type in valid_key_type_pairs - } - - # Test valid `named_values` - try: - check_named_values(STAGE_SETUP, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SETUP - - # Test invalid `named_values` - for key, value_type in valid_key_type_pairs: - invalid_named_values = valid_named_values.copy() - - # Test wrong value type for the key - for other_type, other_value in type_to_test_value.items(): - if other_type == value_type: - continue - invalid_named_values[key] = other_value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - # Test missing key - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - def test_stage_share_keys_check(self) -> None: - """Test content checking for the share keys stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - "1": [b"public key 1", b"public key 2"], - "2": [b"public key 1", b"public key 2"], - "3": [b"public key 1", b"public key 2"], - } - - # Test valid `named_values` - try: - check_named_values(STAGE_SHARE_KEYS, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SHARE_KEYS - - # Test invalid `named_values` - invalid_values: List[Value] = [ - b"public key 1", - [b"public key 1"], - [b"public key 1", b"public key 2", b"public key 3"], - ] - - for value in invalid_values: - invalid_named_values = valid_named_values.copy() - invalid_named_values["1"] = value - - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SETUP - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - def 
test_stage_collect_masked_input_check(self) -> None: - """Test content checking for the collect masked input stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - KEY_CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], - KEY_SOURCE_LIST: [32, 51324, 32324123, -3], - KEY_PARAMETERS: [b"params1", b"params2"], - } - - # Test valid `named_values` - try: - check_named_values(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT - - # Test invalid `named_values` - # Test missing keys - for key in list(valid_named_values.keys()): - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) - - # Test wrong value type for the key - for key in valid_named_values: - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(3.1415926) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) - - def test_stage_unmask_check(self) -> None: - """Test content checking for the unmasking stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - KEY_ACTIVE_SECURE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - KEY_DEAD_SECURE_ID_LIST: [32, 51324, 32324123, -3], - } - - # Test valid `named_values` - try: - check_named_values(STAGE_UNMASK, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_UNMASK - - # Test invalid `named_values` - # Test missing keys - for key in list(valid_named_values.keys()): - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) - - # Test wrong value type for the key - for key in valid_named_values: - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(True) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 5291afb83d98..956ac7a15c05 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -14,31 +14,15 @@ # ============================================================================== """Custom types for Flower clients.""" -from dataclasses import dataclass + from typing import Callable -from flwr.client.run_state import RunState -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.common import Context, Message from .client import Client as Client - -@dataclass -class Fwd: - """.""" - - task_ins: 
TaskIns - state: RunState - - -@dataclass -class Bwd: - """.""" - - task_res: TaskRes - state: RunState - - -FlowerCallable = Callable[[Fwd], Bwd] +# Compatibility ClientFn = Callable[[str], Client] -Layer = Callable[[Fwd, FlowerCallable], Bwd] + +ClientAppCallable = Callable[[Message, Context], Message] +Mod = Callable[[Message, Context, ClientAppCallable], Message] diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 2f45de45dfc3..9f9ff7ebc68a 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -15,14 +15,26 @@ """Common components shared between server and client.""" +from .constant import MessageType as MessageType +from .constant import MessageTypeLegacy as MessageTypeLegacy +from .context import Context as Context from .date import now as now from .grpc import GRPC_MAX_MESSAGE_LENGTH from .logger import configure as configure from .logger import log as log +from .message import Error as Error +from .message import Message as Message +from .message import Metadata as Metadata from .parameter import bytes_to_ndarray as bytes_to_ndarray from .parameter import ndarray_to_bytes as ndarray_to_bytes from .parameter import ndarrays_to_parameters as ndarrays_to_parameters from .parameter import parameters_to_ndarrays as parameters_to_ndarrays +from .record import Array as Array +from .record import ConfigsRecord as ConfigsRecord +from .record import MetricsRecord as MetricsRecord +from .record import ParametersRecord as ParametersRecord +from .record import RecordSet as RecordSet +from .record import array_from_numpy as array_from_numpy from .telemetry import EventType as EventType from .telemetry import event as event from .typing import ClientMessage as ClientMessage @@ -49,11 +61,15 @@ from .typing import Status as Status __all__ = [ + "Array", + "array_from_numpy", "bytes_to_ndarray", "ClientMessage", "Code", "Config", + "ConfigsRecord", "configure", + "Context", "DisconnectRes", "EvaluateIns", "EvaluateRes", @@ -61,14 +77,20 @@ "EventType", "FitIns", "FitRes", + "Error", "GetParametersIns", "GetParametersRes", "GetPropertiesIns", "GetPropertiesRes", "GRPC_MAX_MESSAGE_LENGTH", "log", + "Message", + "MessageType", + "MessageTypeLegacy", + "Metadata", "Metrics", "MetricsAggregationFn", + "MetricsRecord", "ndarray_to_bytes", "now", "NDArray", @@ -76,8 +98,10 @@ "ndarrays_to_parameters", "Parameters", "parameters_to_ndarrays", + "ParametersRecord", "Properties", "ReconnectIns", + "RecordSet", "Scalar", "ServerMessage", "Status", diff --git a/src/py/flwr/common/address_test.py b/src/py/flwr/common/address_test.py index c12dd5fd289e..420b89871d69 100644 --- a/src/py/flwr/common/address_test.py +++ b/src/py/flwr/common/address_test.py @@ -109,10 +109,10 @@ def test_domain_correct() -> None: """Test if a correct domain address is correctly parsed.""" # Prepare addresses = [ - ("flower.dev:123", ("flower.dev", 123, None)), - ("sub.flower.dev:123", ("sub.flower.dev", 123, None)), - ("sub2.sub1.flower.dev:123", ("sub2.sub1.flower.dev", 123, None)), - ("s5.s4.s3.s2.s1.flower.dev:123", ("s5.s4.s3.s2.s1.flower.dev", 123, None)), + ("flower.ai:123", ("flower.ai", 123, None)), + ("sub.flower.ai:123", ("sub.flower.ai", 123, None)), + ("sub2.sub1.flower.ai:123", ("sub2.sub1.flower.ai", 123, None)), + ("s5.s4.s3.s2.s1.flower.ai:123", ("s5.s4.s3.s2.s1.flower.ai", 123, None)), ("localhost:123", ("localhost", 123, None)), ("https://localhost:123", ("https://localhost", 123, None)), ("http://localhost:123", ("http://localhost", 123, None)), @@ -130,8 
+130,8 @@ def test_domain_incorrect() -> None: """Test if an incorrect domain address returns None.""" # Prepare addresses = [ - "flower.dev", - "flower.dev:65536", + "flower.ai", + "flower.ai:65536", ] for address in addresses: diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py deleted file mode 100644 index b0480841e06c..000000000000 --- a/src/py/flwr/common/configsrecord.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConfigsRecord.""" - - -from dataclasses import dataclass, field -from typing import Dict, Optional, get_args - -from .typing import ConfigsRecordValues, ConfigsScalar - - -@dataclass -class ConfigsRecord: - """Configs record.""" - - data: Dict[str, ConfigsRecordValues] = field(default_factory=dict) - - def __init__( - self, - configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None, - keep_input: bool = True, - ): - """Construct a ConfigsRecord object. - - Parameters - ---------- - configs_dict : Optional[Dict[str, ConfigsRecordValues]] - A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as - defined in `ConfigsScalar`) and lists of such types (see - `ConfigsScalarList`). - keep_input : bool (default: True) - A boolean indicating whether config passed should be deleted from the input - dictionary immediately after adding them to the record. When set - to True, the data is duplicated in memory. If memory is a concern, set - it to False. - """ - self.data = {} - if configs_dict: - self.set_configs(configs_dict, keep_input=keep_input) - - def set_configs( - self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True - ) -> None: - """Add configs to the record. - - Parameters - ---------- - configs_dict : Dict[str, ConfigsRecordValues] - A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as - defined in `ConfigsRecordValues`) and list of such types (see - `ConfigsScalarList`). - keep_input : bool (default: True) - A boolean indicating whether config passed should be deleted from the input - dictionary immediately after adding them to the record. When set - to True, the data is duplicated in memory. If memory is a concern, set - it to False. - """ - if any(not isinstance(k, str) for k in configs_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}") - - def is_valid(value: ConfigsScalar) -> None: - """Check if value is of expected type.""" - if not isinstance(value, get_args(ConfigsScalar)): - raise TypeError( - "Not all values are of valid type." - f" Expected {ConfigsRecordValues} but you passed {type(value)}." - ) - - # Check types of values - # Split between those values that are list and those that aren't - # then process in the same way - for value in configs_dict.values(): - if isinstance(value, list): - # If your lists are large (e.g. 
1M+ elements) this will be slow - # 1s to check 10M element list on a M2 Pro - # In such settings, you'd be better of treating such config as - # an array and pass it to a ParametersRecord. - # Empty lists are valid - if len(value) > 0: - is_valid(value[0]) - # all elements in the list must be of the same valid type - # this is needed for protobuf - value_type = type(value[0]) - if not all(isinstance(v, value_type) for v in value): - raise TypeError( - "All values in a list must be of the same valid type. " - f"One of {ConfigsScalar}." - ) - else: - is_valid(value) - - # Add configs to record - if keep_input: - # Copy - self.data = configs_dict.copy() - else: - # Add entries to dataclass without duplicating memory - for key in list(configs_dict.keys()): - self.data[key] = configs_dict[key] - del configs_dict[key] - - def __getitem__(self, key: str) -> ConfigsRecordValues: - """Retrieve an element stored in record.""" - return self.data[key] diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 49802f2815be..7d30a10f5881 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -15,6 +15,8 @@ """Flower constants.""" +from __future__ import annotations + MISSING_EXTRA_REST = """ Extra dependencies required for using the REST-based Fleet API are missing. @@ -26,8 +28,43 @@ TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi" TRANSPORT_TYPE_GRPC_RERE = "grpc-rere" TRANSPORT_TYPE_REST = "rest" +TRANSPORT_TYPE_VCE = "vce" TRANSPORT_TYPES = [ TRANSPORT_TYPE_GRPC_BIDI, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_VCE, ] + + +class MessageType: + """Message type.""" + + TRAIN = "train" + EVALUATE = "evaluate" + QUERY = "query" + + def __new__(cls) -> MessageType: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class MessageTypeLegacy: + """Legacy message type.""" + + GET_PROPERTIES = "get_properties" + GET_PARAMETERS = "get_parameters" + + def __new__(cls) -> MessageTypeLegacy: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class SType: + """Serialisation type.""" + + NUMPY = "numpy.ndarray" + + def __new__(cls) -> SType: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py new file mode 100644 index 000000000000..b6349307d150 --- /dev/null +++ b/src/py/flwr/common/context.py @@ -0,0 +1,38 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Context.""" + + +from dataclasses import dataclass + +from .record import RecordSet + + +@dataclass +class Context: + """State of your run. + + Parameters + ---------- + state : RecordSet + Holds records added by the entity in a given run and that will stay local. + This means that the data it holds will never leave the system it's running from. 
+ This can be used as an intermediate storage or scratchpad when + executing mods. It can also be used as a memory to access + at different points during the lifecycle of this entity (e.g. across + multiple rounds). + """ + + state: RecordSet
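For reference, a minimal sketch of using `Context.state` as a local scratchpad. `Context`, `RecordSet`, and `MetricsRecord` are all exported from `flwr.common` in this patch; the `metrics_records` attribute accessed below is an assumption about the `RecordSet` API, named here for illustration only:

```python
from flwr.common import Context, MetricsRecord, RecordSet

# Create a Context backed by an empty RecordSet; anything stored in
# `ctx.state` stays on the node it runs on for the whole run.
ctx = Context(state=RecordSet())

# Hypothetical scratchpad use across rounds; `metrics_records` is an
# assumed RecordSet attribute, not confirmed by this diff.
ctx.state.metrics_records["fit_stats"] = MetricsRecord({"num_examples": 512})
```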
diff --git a/src/py/flwr/common/differential_privacy.py b/src/py/flwr/common/differential_privacy.py new file mode 100644 index 000000000000..85dc198ef8a0 --- /dev/null +++ b/src/py/flwr/common/differential_privacy.py @@ -0,0 +1,176 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for differential privacy.""" + + +from logging import WARNING +from typing import Optional, Tuple + +import numpy as np + +from flwr.common import ( + NDArrays, + Parameters, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log + + +def get_norm(input_arrays: NDArrays) -> float: + """Compute the L2 norm of the flattened input.""" + array_norms = [np.linalg.norm(array.flat) for array in input_arrays] + # pylint: disable=consider-using-generator + return float(np.sqrt(sum([norm**2 for norm in array_norms]))) + + +def add_gaussian_noise_inplace(input_arrays: NDArrays, std_dev: float) -> None: + """Add Gaussian noise to each element of the input arrays.""" + for array in input_arrays: + array += np.random.normal(0, std_dev, array.shape) + + +def clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> None: + """Clip model update based on the clipping norm in-place. + + FlatClip method of the paper: https://arxiv.org/abs/1710.06963 + """ + input_norm = get_norm(input_arrays) + scaling_factor = min(1, clipping_norm / input_norm) + for array in input_arrays: + array *= scaling_factor + + +def compute_stdv( + noise_multiplier: float, clipping_norm: float, num_sampled_clients: int +) -> float: + """Compute standard deviation for noise addition. + + Paper: https://arxiv.org/abs/1710.06963 + """ + return float((noise_multiplier * clipping_norm) / num_sampled_clients) + + +def compute_clip_model_update( + param1: NDArrays, param2: NDArrays, clipping_norm: float +) -> None: + """Compute model update (param1 - param2) and clip it. + + Then add the clipped update to param2 and store the result in param1.""" + model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)] + clip_inputs_inplace(model_update, clipping_norm) + + for i, _ in enumerate(param2): + param1[i] = param2[i] + model_update[i] + + +def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool: + """Clip model update based on the clipping norm in-place. + + It returns True if scaling_factor < 1, which is used for the norm_bit. + FlatClip method of the paper: https://arxiv.org/abs/1710.06963 + """ + input_norm = get_norm(input_arrays) + scaling_factor = min(1, clipping_norm / input_norm) + for array in input_arrays: + array *= scaling_factor + return scaling_factor < 1 + + +def compute_adaptive_clip_model_update( + param1: NDArrays, param2: NDArrays, clipping_norm: float +) -> bool: + """Compute model update, clip it, then store param2 plus the clipped update in param1. + + model update = param1 - param2 + Return the norm_bit + """ + model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)] + norm_bit = adaptive_clip_inputs_inplace(model_update, clipping_norm) + + for i, _ in enumerate(param2): + param1[i] = param2[i] + model_update[i] + + return norm_bit + + +def add_gaussian_noise_to_params( + model_params: Parameters, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, +) -> Parameters: + """Add Gaussian noise to model parameters.""" + model_params_ndarrays = parameters_to_ndarrays(model_params) + add_gaussian_noise_inplace( + model_params_ndarrays, + compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients), + ) + return ndarrays_to_parameters(model_params_ndarrays) + + +def compute_adaptive_noise_params( + noise_multiplier: float, + num_sampled_clients: float, + clipped_count_stddev: Optional[float], +) -> Tuple[float, float]: + """Compute noising parameters for the adaptive clipping. + + Paper: https://arxiv.org/abs/1905.03871 + """ + if noise_multiplier > 0: + if clipped_count_stddev is None: + clipped_count_stddev = num_sampled_clients / 20 + if noise_multiplier >= 2 * clipped_count_stddev: + raise ValueError( + f"If not specified, `clipped_count_stddev` is set to " + f"`num_sampled_clients`/20 by default. This value " + f"({num_sampled_clients / 20}) is too low to achieve the " + f"desired effective `noise_multiplier` ({noise_multiplier}). " + f"Consider increasing `clipped_count_stddev` or decreasing " + f"`noise_multiplier`." + ) + noise_multiplier_value = ( + noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2) + ) ** -0.5 + + adding_noise = noise_multiplier_value / noise_multiplier + if adding_noise >= 2: + log( + WARNING, + "A significant amount of noise (%s) has to be " + "added. Consider increasing `clipped_count_stddev` or " + "`num_sampled_clients`.", + adding_noise, + ) + + else: + if clipped_count_stddev is None: + clipped_count_stddev = 0.0 + noise_multiplier_value = 0.0 + + return clipped_count_stddev, noise_multiplier_value + + +def add_localdp_gaussian_noise_to_params( + model_params: Parameters, sensitivity: float, epsilon: float, delta: float +) -> Parameters: + """Add local DP Gaussian noise to model parameters.""" + model_params_ndarrays = parameters_to_ndarrays(model_params) + add_gaussian_noise_inplace( + model_params_ndarrays, + sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon, + ) + return ndarrays_to_parameters(model_params_ndarrays) diff --git a/src/py/flwr/common/differential_privacy_constants.py b/src/py/flwr/common/differential_privacy_constants.py new file mode 100644 index 000000000000..415759dfdf0b --- /dev/null +++ b/src/py/flwr/common/differential_privacy_constants.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants for differential privacy.""" + + +KEY_CLIPPING_NORM = "clipping_norm" +KEY_NORM_BIT = "norm_bit" +CLIENTS_DISCREPANCY_WARNING = ( + "The number of clients returning parameters (%s)" + " differs from the number of sampled clients (%s)." + " This could impact the differential privacy guarantees," + " potentially leading to privacy leakage or inadequate noise calibration." +) diff --git a/src/py/flwr/common/differential_privacy_test.py b/src/py/flwr/common/differential_privacy_test.py new file mode 100644 index 000000000000..ea2e193df2f3 --- /dev/null +++ b/src/py/flwr/common/differential_privacy_test.py @@ -0,0 +1,209 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differential Privacy (DP) utility functions tests.""" + + +import numpy as np + +from .differential_privacy import ( + add_gaussian_noise_inplace, + clip_inputs_inplace, + compute_adaptive_noise_params, + compute_clip_model_update, + compute_stdv, + get_norm, +) + + +def test_add_gaussian_noise_inplace() -> None: + """Test add_gaussian_noise_inplace function.""" + # Prepare + update = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])] + std_dev = 0.1 + + # Execute + add_gaussian_noise_inplace(update, std_dev) + + # Assert + # Check that the shape of the result is the same as the input + for layer in update: + assert layer.shape == (2, 2) + + # Check that the values have been changed and are not equal to the original update + for layer in update: + assert not np.array_equal( + layer, [[1.0, 2.0], [3.0, 4.0]] + ) and not np.array_equal(layer, [[5.0, 6.0], [7.0, 8.0]]) + + # Check that the noise has been added + for layer in update: + noise_added = ( + layer - np.array([[1.0, 2.0], [3.0, 4.0]]) + if np.array_equal(layer, [[1.0, 2.0], [3.0, 4.0]]) + else layer - np.array([[5.0, 6.0], [7.0, 8.0]]) + ) + assert np.any(np.abs(noise_added) > 0) + + +def test_get_norm() -> None: + """Test get_norm function.""" + # Prepare + update = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + + # Execute + result = get_norm(update) + + expected = float( + np.linalg.norm(np.concatenate([sub_update.flatten() for sub_update in update])) + ) + + # Assert + assert expected == result + + +def test_clip_inputs_inplace() -> None: + """Test clip_inputs_inplace function.""" + # Prepare + updates = [ + np.array([[1.5, -0.5], [2.0, -1.0]]), + np.array([0.5, -0.5]), + np.array([[-0.5, 1.5], 
[-1.0, 2.0]]), + np.array([-0.5, 0.5]), + ] + clipping_norm = 1.5 + + original_updates = [np.copy(update) for update in updates] + + # Execute + clip_inputs_inplace(updates, clipping_norm) + + # Assert + for updated, original_update in zip(updates, original_updates): + clip_norm = np.linalg.norm(original_update) + assert np.all(updated <= clip_norm) and np.all(updated >= -clip_norm) + + +def test_compute_stdv() -> None: + """Test compute_stdv function.""" + # Prepare + noise_multiplier = 1.0 + clipping_norm = 0.5 + num_sampled_clients = 10 + + # Execute + stdv = compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients) + + # Assert + expected_stdv = float((noise_multiplier * clipping_norm) / num_sampled_clients) + assert stdv == expected_stdv + + +def test_compute_clip_model_update() -> None: + """Test compute_clip_model_update function.""" + # Prepare + param1 = [ + np.array([0.5, 1.5, 2.5]), + np.array([3.5, 4.5, 5.5]), + np.array([6.5, 7.5, 8.5]), + ] + param2 = [ + np.array([1.0, 2.0, 3.0]), + np.array([4.0, 5.0, 6.0]), + np.array([7.0, 8.0, 9.0]), + ] + clipping_norm = 4 + + expected_result = [ + np.array([0.5, 1.5, 2.5]), + np.array([3.5, 4.5, 5.5]), + np.array([6.5, 7.5, 8.5]), + ] + + # Execute + compute_clip_model_update(param1, param2, clipping_norm) + + # Assert + for i, param in enumerate(param1): + np.testing.assert_array_almost_equal(param, expected_result[i]) + + +def test_compute_adaptive_noise_params() -> None: + """Test compute_adaptive_noise_params function.""" + # Test valid input with positive noise_multiplier + noise_multiplier = 1.0 + num_sampled_clients = 100.0 + clipped_count_stddev = None + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] > 0.0 + assert result[1] > 0.0 + + # Test valid input with zero noise_multiplier + noise_multiplier = 0.0 + num_sampled_clients = 50.0 + clipped_count_stddev = None + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] == 0.0 + assert result[1] == 0.0 + + # Test valid input with specified clipped_count_stddev + noise_multiplier = 3.0 + num_sampled_clients = 200.0 + clipped_count_stddev = 5.0 + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] == clipped_count_stddev + assert result[1] > 0.0 + + # Test invalid input with noise_multiplier >= 2 * clipped_count_stddev + noise_multiplier = 10.0 + num_sampled_clients = 100.0 + clipped_count_stddev = None + try: + compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + except ValueError: + pass + else: + raise AssertionError("Expected ValueError not raised.") + + # Test intermediate calculation + noise_multiplier = 3.0 + num_sampled_clients = 200.0 + clipped_count_stddev = 5.0 + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + temp_value = (noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)) ** -0.5 + + # Assert + assert np.isclose(result[1], temp_value, rtol=1e-6)
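The clip-then-noise pipeline added in `differential_privacy.py` above composes as follows; a minimal sketch using only the helpers defined in this patch, with toy values of our own choosing:

```python
import numpy as np

from flwr.common.differential_privacy import (
    add_gaussian_noise_inplace,
    clip_inputs_inplace,
    compute_stdv,
    get_norm,
)

# Toy two-layer model update.
update = [np.array([3.0, 4.0]), np.array([[1.0, 2.0], [2.0, 1.0]])]

# FlatClip: rescale the whole update so its global L2 norm is <= 2.0.
clipping_norm = 2.0
clip_inputs_inplace(update, clipping_norm)
assert get_norm(update) <= clipping_norm + 1e-9

# Noise std dev is (noise_multiplier * clipping_norm) / num_sampled_clients.
std_dev = compute_stdv(
    noise_multiplier=1.0, clipping_norm=clipping_norm, num_sampled_clients=10
)
add_gaussian_noise_inplace(update, std_dev)
```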
diff --git a/src/py/flwr/common/dp.py b/src/py/flwr/common/dp.py index 5030ad34805b..83a72b8ce749 100644 --- a/src/py/flwr/common/dp.py +++ b/src/py/flwr/common/dp.py @@ -19,11 +19,13 @@ import numpy as np +from flwr.common.logger import warn_deprecated_feature from flwr.common.typing import NDArrays # Calculates the L2-norm of a potentially ragged array def _get_update_norm(update: NDArrays) -> float: + warn_deprecated_feature("`_get_update_norm` method") flattened_update = update[0] for i in range(1, len(update)): flattened_update = np.append(flattened_update, update[i]) @@ -32,6 +34,7 @@ def _get_update_norm(update: NDArrays) -> float: def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: """Add iid Gaussian noise to each floating point value in the update.""" + warn_deprecated_feature("`add_gaussian_noise` method") update_noised = [ layer + np.random.normal(0, std_dev, layer.shape) for layer in update ] @@ -40,6 +43,7 @@ def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: def clip_by_l2(update: NDArrays, threshold: float) -> Tuple[NDArrays, bool]: """Scales the update so thats its L2 norm is upper-bound to threshold.""" + warn_deprecated_feature("`clip_by_l2` method") update_norm = _get_update_norm(update) scaling_factor = min(1, threshold / update_norm) update_clipped: NDArrays = [layer * scaling_factor for layer in update] diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py new file mode 100644 index 000000000000..30750c28a450 --- /dev/null +++ b/src/py/flwr/common/exit_handlers.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common function to register exit handlers for server and client.""" + + +import sys +from signal import SIGINT, SIGTERM, signal +from threading import Thread +from types import FrameType +from typing import List, Optional + +from grpc import Server + +from flwr.common.telemetry import EventType, event + + +def register_exit_handlers( + event_type: EventType, + grpc_servers: Optional[List[Server]] = None, + bckg_threads: Optional[List[Thread]] = None, +) -> None: + """Register exit handlers for `SIGINT` and `SIGTERM` signals. + + Parameters + ---------- + event_type : EventType + The telemetry event that should be logged before exit. + grpc_servers: Optional[List[Server]] (default: None) + An optional list of gRPC servers that need to be gracefully + terminated before exiting. + bckg_threads: Optional[List[Thread]] (default: None) + An optional list of threads that need to be gracefully + terminated before exiting. + """ + default_handlers = { + SIGINT: None, + SIGTERM: None, + } + + def graceful_exit_handler( # type: ignore + signalnum, + frame: FrameType, # pylint: disable=unused-argument + ) -> None: + """Exit handler to be registered with `signal.signal`. + + When called, it will reset the signal handler to the original handler from + default_handlers. + """ + # Reset to default handler + signal(signalnum, default_handlers[signalnum]) + + event_res = event(event_type=event_type) + + if grpc_servers is not None: + for grpc_server in grpc_servers: + grpc_server.stop(grace=1) + + if bckg_threads is not None: + for bckg_thread in bckg_threads: + bckg_thread.join() + + # Ensure the event has happened + event_res.result() + + # Exit gracefully + sys.exit(0) + + default_handlers[SIGINT] = signal( # type: ignore + SIGINT, + graceful_exit_handler, # type: ignore + ) + default_handlers[SIGTERM] = signal( # type: ignore + SIGTERM, + graceful_exit_handler, # type: ignore + )
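`register_exit_handlers` is meant to be called once during process bootstrap. A minimal sketch wiring it to a bare `grpc` server; the port and the choice of `EventType.START_SERVER_LEAVE` are illustrative, not taken from this patch:

```python
from concurrent import futures

import grpc

from flwr.common import EventType
from flwr.common.exit_handlers import register_exit_handlers

# Start a bare gRPC server (stand-in for however the real server is created).
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server.add_insecure_port("[::]:9091")
server.start()

# SIGINT/SIGTERM now log the telemetry event, stop the server with a
# one-second grace period, and exit the process with status 0.
register_exit_handlers(
    event_type=EventType.START_SERVER_LEAVE,
    grpc_servers=[server],
)
server.wait_for_termination()
```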
+ """ + # Reset to default handler + signal(signalnum, default_handlers[signalnum]) + + event_res = event(event_type=event_type) + + if grpc_servers is not None: + for grpc_server in grpc_servers: + grpc_server.stop(grace=1) + + if bckg_threads is not None: + for bckg_thread in bckg_threads: + bckg_thread.join() + + # Ensure event has happend + event_res.result() + + # Setup things for graceful exit + sys.exit(0) + + default_handlers[SIGINT] = signal( # type: ignore + SIGINT, + graceful_exit_handler, # type: ignore + ) + default_handlers[SIGTERM] = signal( # type: ignore + SIGTERM, + graceful_exit_handler, # type: ignore + ) diff --git a/src/py/flwr/common/flowercontext.py b/src/py/flwr/common/flowercontext.py deleted file mode 100644 index 6e26d93bfe9a..000000000000 --- a/src/py/flwr/common/flowercontext.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FlowerContext and Metadata.""" - - -from dataclasses import dataclass - -from .recordset import RecordSet - - -@dataclass -class Metadata: - """A dataclass holding metadata associated with the current task. - - Parameters - ---------- - run_id : int - An identifier for the current run. - task_id : str - An identifier for the current task. - group_id : str - An identifier for grouping tasks. In some settings - this is used as the FL round. - ttl : str - Time-to-live for this task. - task_type : str - A string that encodes the action to be executed on - the receiving end. - """ - - run_id: int - task_id: str - group_id: str - ttl: str - task_type: str - - -@dataclass -class FlowerContext: - """State of your application from the viewpoint of the entity using it. - - Parameters - ---------- - in_message : RecordSet - Holds records sent by another entity (e.g. sent by the server-side - logic to a client, or vice-versa) - out_message : RecordSet - Holds records added by the current entity. This `RecordSet` will - be sent out (e.g. back to the server-side for aggregation of - parameter, or to the client to perform a certain task) - local : RecordSet - Holds record added by the current entity and that will stay local. - This means that the data it holds will never leave the system it's running from. - This can be used as an intermediate storage or scratchpad when - executing middleware layers. It can also be used as a memory to access - at different points during the lifecycle of this entity (e.g. across - multiple rounds) - metadata : Metadata - A dataclass including information about the task to be executed. 
- """ - - in_message: RecordSet - out_message: RecordSet - local: RecordSet - metadata: Metadata diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index 9d0543ea8c75..7d0eba078ab0 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -15,7 +15,7 @@ """Utility functions for gRPC.""" -from logging import INFO +from logging import DEBUG from typing import Optional import grpc @@ -49,12 +49,12 @@ def create_channel( if insecure: channel = grpc.insecure_channel(server_address, options=channel_options) - log(INFO, "Opened insecure gRPC connection (no certificates were passed)") + log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)") else: ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates) channel = grpc.secure_channel( server_address, ssl_channel_credentials, options=channel_options ) - log(INFO, "Opened secure gRPC connection using certificates") + log(DEBUG, "Opened secure gRPC connection using certificates") return channel diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 50c902da38b5..2bc41773ed61 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -18,21 +18,86 @@ import logging from logging import WARN, LogRecord from logging.handlers import HTTPHandler -from typing import Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple # Create logger LOGGER_NAME = "flwr" FLOWER_LOGGER = logging.getLogger(LOGGER_NAME) FLOWER_LOGGER.setLevel(logging.DEBUG) -DEFAULT_FORMATTER = logging.Formatter( - "%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s" -) +LOG_COLORS = { + "DEBUG": "\033[94m", # Blue + "INFO": "\033[92m", # Green + "WARNING": "\033[93m", # Yellow + "ERROR": "\033[91m", # Red + "CRITICAL": "\033[95m", # Magenta + "RESET": "\033[0m", # Reset to default +} + +if TYPE_CHECKING: + StreamHandler = logging.StreamHandler[Any] +else: + StreamHandler = logging.StreamHandler + + +class ConsoleHandler(StreamHandler): + """Console handler that allows configurable formatting.""" + + def __init__( + self, + timestamps: bool = False, + json: bool = False, + colored: bool = True, + stream: Optional[TextIO] = None, + ) -> None: + super().__init__(stream) + self.timestamps = timestamps + self.json = json + self.colored = colored + + def emit(self, record: LogRecord) -> None: + """Emit a record.""" + if self.json: + record.message = record.getMessage().replace("\t", "").strip() + + # Check if the message is empty + if not record.message: + return + + super().emit(record) + + def format(self, record: LogRecord) -> str: + """Format function that adds colors to log level.""" + seperator = " " * (8 - len(record.levelname)) + if self.json: + log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}" + else: + log_fmt = ( + f"{LOG_COLORS[record.levelname] if self.colored else ''}" + f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}" + f"{LOG_COLORS['RESET'] if self.colored else ''}" + f": {seperator} %(message)s" + ) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +def update_console_handler(level: int, timestamps: bool, colored: bool) -> None: + """Update the logging handler.""" + for handler in logging.getLogger(LOGGER_NAME).handlers: + if isinstance(handler, ConsoleHandler): + handler.setLevel(level) + handler.timestamps = timestamps + handler.colored = colored + # Configure console logger -console_handler = logging.StreamHandler() 
-console_handler.setLevel(logging.DEBUG) -console_handler.setFormatter(DEFAULT_FORMATTER) +console_handler = ConsoleHandler( + timestamps=False, + json=False, + colored=True, +) +console_handler.setLevel(logging.INFO) FLOWER_LOGGER.addHandler(console_handler) @@ -103,11 +168,10 @@ def warn_experimental_feature(name: str) -> None: """Warn the user when they use an experimental feature.""" log( WARN, - """ - EXPERIMENTAL FEATURE: %s + """EXPERIMENTAL FEATURE: %s - This is an experimental feature. It could change significantly or be removed - entirely in future versions of Flower. + This is an experimental feature. It could change significantly or be removed + entirely in future versions of Flower. """, name, ) @@ -117,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None: """Warn the user when they use a deprecated feature.""" log( WARN, - """ - DEPRECATED FEATURE: %s + """DEPRECATED FEATURE: %s - This is a deprecated feature. It will be removed - entirely in future versions of Flower. + This is a deprecated feature. It will be removed + entirely in future versions of Flower. """, name, ) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py new file mode 100644 index 000000000000..88cf750f1a94 --- /dev/null +++ b/src/py/flwr/common/message.py @@ -0,0 +1,323 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Message.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .record import RecordSet + + +@dataclass +class Metadata: # pylint: disable=too-many-instance-attributes + """A dataclass holding metadata associated with the current message. + + Parameters + ---------- + run_id : int + An identifier for the current run. + message_id : str + An identifier for the current message. + src_node_id : int + An identifier for the node sending this message. + dst_node_id : int + An identifier for the node receiving this message. + reply_to_message : str + An identifier for the message this message replies to. + group_id : str + An identifier for grouping messages. In some settings, + this is used as the FL round. + ttl : str + Time-to-live for this message. + message_type : str + A string that encodes the action to be executed on + the receiving end. + partition_id : Optional[int] + An identifier that can be used when loading a particular + data partition for a ClientApp. Making use of this identifier + is more relevant when conducting simulations. 
+ """ + + _run_id: int + _message_id: str + _src_node_id: int + _dst_node_id: int + _reply_to_message: str + _group_id: str + _ttl: str + _message_type: str + _partition_id: int | None + + def __init__( # pylint: disable=too-many-arguments + self, + run_id: int, + message_id: str, + src_node_id: int, + dst_node_id: int, + reply_to_message: str, + group_id: str, + ttl: str, + message_type: str, + partition_id: int | None = None, + ) -> None: + self._run_id = run_id + self._message_id = message_id + self._src_node_id = src_node_id + self._dst_node_id = dst_node_id + self._reply_to_message = reply_to_message + self._group_id = group_id + self._ttl = ttl + self._message_type = message_type + self._partition_id = partition_id + + @property + def run_id(self) -> int: + """An identifier for the current run.""" + return self._run_id + + @property + def message_id(self) -> str: + """An identifier for the current message.""" + return self._message_id + + @property + def src_node_id(self) -> int: + """An identifier for the node sending this message.""" + return self._src_node_id + + @property + def reply_to_message(self) -> str: + """An identifier for the message this message replies to.""" + return self._reply_to_message + + @property + def dst_node_id(self) -> int: + """An identifier for the node receiving this message.""" + return self._dst_node_id + + @dst_node_id.setter + def dst_node_id(self, value: int) -> None: + """Set dst_node_id.""" + self._dst_node_id = value + + @property + def group_id(self) -> str: + """An identifier for grouping messages.""" + return self._group_id + + @group_id.setter + def group_id(self, value: str) -> None: + """Set group_id.""" + self._group_id = value + + @property + def ttl(self) -> str: + """Time-to-live for this message.""" + return self._ttl + + @ttl.setter + def ttl(self, value: str) -> None: + """Set ttl.""" + self._ttl = value + + @property + def message_type(self) -> str: + """A string that encodes the action to be executed on the receiving end.""" + return self._message_type + + @message_type.setter + def message_type(self, value: str) -> None: + """Set message_type.""" + self._message_type = value + + @property + def partition_id(self) -> int | None: + """An identifier telling which data partition a ClientApp should use.""" + return self._partition_id + + @partition_id.setter + def partition_id(self, value: int) -> None: + """Set patition_id.""" + self._partition_id = value + + +@dataclass +class Error: + """A dataclass that stores information about an error that occurred. + + Parameters + ---------- + code : int + An identifier for the error. + reason : Optional[str] + A reason for why the error arose (e.g. an exception stack-trace) + """ + + _code: int + _reason: str | None = None + + def __init__(self, code: int, reason: str | None = None) -> None: + self._code = code + self._reason = reason + + @property + def code(self) -> int: + """Error code.""" + return self._code + + @property + def reason(self) -> str | None: + """Reason reported about the error.""" + return self._reason + + +@dataclass +class Message: + """State of your application from the viewpoint of the entity using it. + + Parameters + ---------- + metadata : Metadata + A dataclass including information about the message to be executed. + content : Optional[RecordSet] + Holds records either sent by another entity (e.g. sent by the server-side + logic to a client, or vice-versa) or that will be sent to it. 
+ error : Optional[Error] + A dataclass that captures information about an error that took place + when processing another message. + """ + + _metadata: Metadata + _content: RecordSet | None = None + _error: Error | None = None + + def __init__( + self, + metadata: Metadata, + content: RecordSet | None = None, + error: Error | None = None, + ) -> None: + self._metadata = metadata + + if not (content is None) ^ (error is None): + raise ValueError("Either `content` or `error` must be set, but not both.") + + self._content = content + self._error = error + + @property + def metadata(self) -> Metadata: + """A dataclass including information about the message to be executed.""" + return self._metadata + + @property + def content(self) -> RecordSet: + """The content of this message.""" + if self._content is None: + raise ValueError( + "Message content is None. Use .has_content() " + "to check if a message has content." + ) + return self._content + + @content.setter + def content(self, value: RecordSet) -> None: + """Set content.""" + if self._error is None: + self._content = value + else: + raise ValueError("A message with an error set cannot have content.") + + @property + def error(self) -> Error: + """Error captured by this message.""" + if self._error is None: + raise ValueError( + "Message error is None. Use .has_error() " + "to check first if a message carries an error." + ) + return self._error + + @error.setter + def error(self, value: Error) -> None: + """Set error.""" + if self.has_content(): + raise ValueError("A message with content set cannot carry an error.") + self._error = value + + def has_content(self) -> bool: + """Return True if message has content, else False.""" + return self._content is not None + + def has_error(self) -> bool: + """Return True if message has an error, else False.""" + return self._error is not None + + def _create_reply_metadata(self, ttl: str) -> Metadata: + """Construct metadata for a reply message.""" + return Metadata( + run_id=self.metadata.run_id, + message_id="", + src_node_id=self.metadata.dst_node_id, + dst_node_id=self.metadata.src_node_id, + reply_to_message=self.metadata.message_id, + group_id=self.metadata.group_id, + ttl=ttl, + message_type=self.metadata.message_type, + partition_id=self.metadata.partition_id, + ) + + def create_error_reply( + self, + error: Error, + ttl: str, + ) -> Message: + """Construct a reply message indicating an error happened. + + Parameters + ---------- + error : Error + The error that was encountered. + ttl : str + Time-to-live for this message. + """ + # Create reply with error + message = Message(metadata=self._create_reply_metadata(ttl), error=error) + return message + + def create_reply(self, content: RecordSet, ttl: str) -> Message: + """Create a reply to this message with specified content and TTL. + + The method generates a new `Message` as a reply to this message. + It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from + this message and sets 'reply_to_message' to the ID of this message. + + Parameters + ---------- + content : RecordSet + The content for the reply message. + ttl : str + Time-to-live for this message. + + Returns + ------- + Message + A new `Message` instance representing the reply. 
+ """ + return Message( + metadata=self._create_reply_metadata(ttl), + content=content, + ) diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py new file mode 100644 index 000000000000..ba628bb3235a --- /dev/null +++ b/src/py/flwr/common/message_test.py @@ -0,0 +1,109 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Message tests.""" + + +from contextlib import ExitStack +from typing import Any, Callable + +import pytest + +# pylint: enable=E0611 +from . import RecordSet +from .message import Error, Message +from .serde_test import RecordMaker + + +@pytest.mark.parametrize( + "content_fn, error_fn, context", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + None, + ), # check when only content is set + (None, lambda code: Error(code=code), None), # check when only error is set + ( + lambda maker: maker.recordset(1, 1, 1), + lambda code: Error(code=code), + pytest.raises(ValueError), + ), # check when both are set (ERROR) + (None, None, pytest.raises(ValueError)), # check when neither is set (ERROR) + ], +) +def test_message_creation( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], + context: Any, +) -> None: + """Test Message creation attempting to pass content and/or error.""" + # Prepare + maker = RecordMaker(state=2) + metadata = maker.metadata() + + with ExitStack() as stack: + if context: + stack.enter_context(context) + + _ = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + +def create_message_with_content() -> Message: + """Create a Message with content.""" + maker = RecordMaker(state=2) + metadata = maker.metadata() + return Message(metadata=metadata, content=RecordSet()) + + +def create_message_with_error() -> Message: + """Create a Message with error.""" + maker = RecordMaker(state=2) + metadata = maker.metadata() + return Message(metadata=metadata, error=Error(code=1)) + + +@pytest.mark.parametrize( + "message_creation_fn", + [ + create_message_with_content, + create_message_with_error, + ], +) +def test_altering_message( + message_creation_fn: Callable[ + [], + Message, + ], +) -> None: + """Test that a message with content doesn't allow setting an error. + + And viceversa. + """ + message = message_creation_fn() + + with pytest.raises(ValueError): + if message.has_content(): + message.error = Error(code=123) + if message.has_error(): + message.content = RecordSet() diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py deleted file mode 100644 index e70b0cb31d55..000000000000 --- a/src/py/flwr/common/metricsrecord.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MetricsRecord.""" - - -from dataclasses import dataclass, field -from typing import Dict, Optional, get_args - -from .typing import MetricsRecordValues, MetricsScalar - - -@dataclass -class MetricsRecord: - """Metrics record.""" - - data: Dict[str, MetricsRecordValues] = field(default_factory=dict) - - def __init__( - self, - metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None, - keep_input: bool = True, - ): - """Construct a MetricsRecord object. - - Parameters - ---------- - metrics_dict : Optional[Dict[str, MetricsRecordValues]] - A dictionary that stores basic types (i.e. `int`, `float` as defined - in `MetricsScalar`) and list of such types (see `MetricsScalarList`). - keep_input : bool (default: True) - A boolean indicating whether metrics should be deleted from the input - dictionary immediately after adding them to the record. When set - to True, the data is duplicated in memory. If memory is a concern, set - it to False. - """ - self.data = {} - if metrics_dict: - self.set_metrics(metrics_dict, keep_input=keep_input) - - def set_metrics( - self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True - ) -> None: - """Add metrics to the record. - - Parameters - ---------- - metrics_dict : Dict[str, MetricsRecordValues] - A dictionary that stores basic types (i.e. `int`, `float` as defined - in `MetricsScalar`) and list of such types (see `MetricsScalarList`). - keep_input : bool (default: True) - A boolean indicating whether metrics should be deleted from the input - dictionary immediately after adding them to the record. When set - to True, the data is duplicated in memory. If memory is a concern, set - it to False. - """ - if any(not isinstance(k, str) for k in metrics_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}.") - - def is_valid(value: MetricsScalar) -> None: - """Check if value is of expected type.""" - if not isinstance(value, get_args(MetricsScalar)) or isinstance( - value, bool - ): - raise TypeError( - "Not all values are of valid type." - f" Expected {MetricsRecordValues} but you passed {type(value)}." - ) - - # Check types of values - # Split between those values that are list and those that aren't - # then process in the same way - for value in metrics_dict.values(): - if isinstance(value, list): - # If your lists are large (e.g. 1M+ elements) this will be slow - # 1s to check 10M element list on a M2 Pro - # In such settings, you'd be better of treating such metric as - # an array and pass it to a ParametersRecord. - # Empty lists are valid - if len(value) > 0: - is_valid(value[0]) - # all elements in the list must be of the same valid type - # this is needed for protobuf - value_type = type(value[0]) - if not all(isinstance(v, value_type) for v in value): - raise TypeError( - "All values in a list must be of the same valid type. " - f"One of {MetricsScalar}." 
- ) - else: - is_valid(value) - - # Add metrics to record - if keep_input: - # Copy - self.data = metrics_dict.copy() - else: - # Add entries to dataclass without duplicating memory - for key in list(metrics_dict.keys()): - self.data[key] = metrics_dict[key] - del metrics_dict[key] - - def __getitem__(self, key: str) -> MetricsRecordValues: - """Retrieve an element stored in record.""" - return self.data[key] diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py new file mode 100644 index 000000000000..4660f07e24a4 --- /dev/null +++ b/src/py/flwr/common/object_ref.py @@ -0,0 +1,140 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper functions to load objects from a reference.""" + + +import ast +import importlib +from importlib.util import find_spec +from typing import Any, Optional, Tuple, Type + +OBJECT_REF_HELP_STR = """ +\n\nThe object reference string should have the form <module>:<attribute>. Valid +examples include `client:app` and `project.package.module:wrapper.app`. It must +refer to a module on the PYTHONPATH and the module needs to have the specified +attribute. +""" + + +def validate( + module_attribute_str: str, +) -> Tuple[bool, Optional[str]]: + """Validate object reference. + + The object reference string should have the form <module>:<attribute>. Valid + examples include `client:app` and `project.package.module:wrapper.app`. It must + refer to a module on the PYTHONPATH and the module needs to have the specified + attribute. + + Returns + ------- + Tuple[bool, Optional[str]] + A boolean indicating whether an object reference is valid and + the reason why it might not be. + """ + module_str, _, attributes_str = module_attribute_str.partition(":") + if not module_str: + return ( + False, + f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}", + ) + if not attributes_str: + return ( + False, + f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}", + ) + + # Load module + module = find_spec(module_str) + if module and module.origin: + if not _find_attribute_in_module(module.origin, attributes_str): + return ( + False, + f"Unable to find attribute {attributes_str} in module {module_str}" + f"{OBJECT_REF_HELP_STR}", + ) + return (True, None) + + return ( + False, + f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", + ) + + +def load_app( + module_attribute_str: str, + error_type: Type[Exception], +) -> Any: + """Return the object specified in a module attribute string. + + The module/attribute string should have the form <module>:<attribute>. Valid + examples include `client:app` and `project.package.module:wrapper.app`. It must + refer to a module on the PYTHONPATH and the module needs to have the specified + attribute. + """ + valid, error_msg = validate(module_attribute_str) + if not valid and error_msg: + raise error_type(error_msg) from None + + module_str, _, attributes_str = module_attribute_str.partition(":") + + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError: + raise error_type( + f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", + ) from None + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError: + raise error_type( + f"Unable to load attribute {attributes_str} from module {module_str}" + f"{OBJECT_REF_HELP_STR}", + ) from None + + return attribute + + +def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool: + """Check if attribute_name exists in the module's abstract syntax tree.""" + with open(file_path, encoding="utf-8") as file: + node = ast.parse(file.read(), filename=file_path) + + for n in ast.walk(node): + if isinstance(n, ast.Assign): + for target in n.targets: + if isinstance(target, ast.Name) and target.id == attribute_name: + return True + if _is_module_in_all(attribute_name, target, n): + return True + return False + + +def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool: + """Check if attribute_name is listed in the module's `__all__`.""" + if isinstance(target, ast.Name) and target.id == "__all__": + if isinstance(n.value, ast.List): + for elt in n.value.elts: + if isinstance(elt, ast.Str) and elt.s == attribute_name: + return True + elif isinstance(n.value, ast.Tuple): + for elt in n.value.elts: + if isinstance(elt, ast.Str) and elt.s == attribute_name: + return True + return False
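Taken together, `validate` and `load_app` implement the `<module>:<attribute>` reference scheme end to end: `validate` inspects the module's AST without importing it, while `load_app` performs the actual import. A short sketch using the same reference the tests below rely on:

```python
from flwr.common.object_ref import load_app, validate

# Check a reference without importing the module.
is_valid, error = validate("flwr.cli.run:run")
assert is_valid and error is None

# Resolve the reference; failures are raised as the given exception type.
run_fn = load_app("flwr.cli.run:run", ImportError)
print(run_fn)  # e.g. <function run at 0x...>
```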
+    """
+    valid, error_msg = validate(module_attribute_str)
+    if not valid and error_msg:
+        raise error_type(error_msg) from None
+
+    module_str, _, attributes_str = module_attribute_str.partition(":")
+
+    try:
+        module = importlib.import_module(module_str)
+    except ModuleNotFoundError:
+        raise error_type(
+            f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
+        ) from None
+
+    # Recursively load attribute
+    attribute = module
+    try:
+        for attribute_str in attributes_str.split("."):
+            attribute = getattr(attribute, attribute_str)
+    except AttributeError:
+        raise error_type(
+            f"Unable to load attribute {attributes_str} from module {module_str}"
+            f"{OBJECT_REF_HELP_STR}",
+        ) from None
+
+    return attribute
+
+
+def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool:
+    """Check if attribute_name exists in the module's abstract syntax tree."""
+    with open(file_path, encoding="utf-8") as file:
+        node = ast.parse(file.read(), filename=file_path)
+
+    for n in ast.walk(node):
+        if isinstance(n, ast.Assign):
+            for target in n.targets:
+                if isinstance(target, ast.Name) and target.id == attribute_name:
+                    return True
+                if _is_module_in_all(attribute_name, target, n):
+                    return True
+    return False
+
+
+def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool:
+    """Check whether attribute_name is listed in the module's __all__."""
+    if isinstance(target, ast.Name) and target.id == "__all__":
+        if isinstance(n.value, ast.List):
+            for elt in n.value.elts:
+                if isinstance(elt, ast.Str) and elt.s == attribute_name:
+                    return True
+        elif isinstance(n.value, ast.Tuple):
+            for elt in n.value.elts:
+                if isinstance(elt, ast.Str) and elt.s == attribute_name:
+                    return True
    return False
diff --git a/src/py/flwr/common/object_ref_test.py b/src/py/flwr/common/object_ref_test.py
new file mode 100644
index 000000000000..f4513a4319ab
--- /dev/null
+++ b/src/py/flwr/common/object_ref_test.py
@@ -0,0 +1,46 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
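For orientation, a minimal sketch of how these two helpers fit together (the `client` module and its `app` attribute are illustrative, not part of this diff):

from flwr.common.object_ref import load_app, validate

# Suppose client.py on the PYTHONPATH defines a module-level `app` attribute.
is_valid, reason = validate("client:app")  # static check against the module's AST
if is_valid:
    # Imports `client` and walks the attribute path, returning `client.app`
    app = load_app("client:app", RuntimeError)
else:
    print(reason)  # the message ends with the OBJECT_REF_HELP_STR guidance

Note that `validate` never imports the module; it only parses the file with `ast`, which is why it can run safely before any user code is executed.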
+# ============================================================================== +"""Tests for the validation function of object refs.""" + +from .object_ref import OBJECT_REF_HELP_STR, validate + + +def test_validate_object_reference() -> None: + """Test that validate_object_reference succeeds correctly.""" + # Prepare + ref = "flwr.cli.run:run" + + # Execute + is_valid, error = validate(ref) + + # Assert + assert is_valid + assert error is None + + +def test_validate_object_reference_fails() -> None: + """Test that validate_object_reference fails correctly.""" + # Prepare + ref = "flwr.cli.run:runa" + + # Execute + is_valid, error = validate(ref) + + # Assert + assert not is_valid + assert ( + error + == f"Unable to find attribute runa in module flwr.cli.run{OBJECT_REF_HELP_STR}" + ) diff --git a/src/py/flwr/common/record/__init__.py b/src/py/flwr/common/record/__init__.py new file mode 100644 index 000000000000..60bc54b8552a --- /dev/null +++ b/src/py/flwr/common/record/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Record APIs.""" + +from .configsrecord import ConfigsRecord +from .conversion_utils import array_from_numpy +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet + +__all__ = [ + "Array", + "array_from_numpy", + "ConfigsRecord", + "MetricsRecord", + "ParametersRecord", + "RecordSet", +] diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py new file mode 100644 index 000000000000..704657601f50 --- /dev/null +++ b/src/py/flwr/common/record/configsrecord.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""ConfigsRecord."""
+
+
+from typing import Dict, Optional, get_args
+
+from flwr.common.typing import ConfigsRecordValues, ConfigsScalar
+
+from .typeddict import TypedDict
+
+
+def _check_key(key: str) -> None:
+    """Check if key is of expected type."""
+    if not isinstance(key, str):
+        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
+
+
+def _check_value(value: ConfigsRecordValues) -> None:
+    def is_valid(__v: ConfigsScalar) -> None:
+        """Check if value is of expected type."""
+        if not isinstance(__v, get_args(ConfigsScalar)):
+            raise TypeError(
+                "Not all values are of valid type."
+                f" Expected `{ConfigsRecordValues}` but `{type(__v)}` was passed."
+            )
+
+    if isinstance(value, list):
+        # If your lists are large (e.g. 1M+ elements) this will be slow
+        # 1s to check 10M element list on a M2 Pro
+        # In such settings, you'd be better off treating such a config as
+        # an array and passing it to a ParametersRecord.
+        # Empty lists are valid
+        if len(value) > 0:
+            is_valid(value[0])
+            # all elements in the list must be of the same valid type
+            # this is needed for protobuf
+            value_type = type(value[0])
+            if not all(isinstance(v, value_type) for v in value):
+                raise TypeError(
+                    "All values in a list must be of the same valid type. "
+                    f"One of {ConfigsScalar}."
+                )
+    else:
+        is_valid(value)
+
+
+class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
+    """Configs record."""
+
+    def __init__(
+        self,
+        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
+        keep_input: bool = True,
+    ) -> None:
+        """Construct a ConfigsRecord object.
+
+        Parameters
+        ----------
+        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
+            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
+            defined in `ConfigsScalar`) and lists of such types (see
+            `ConfigsScalarList`).
+        keep_input : bool (default: True)
+            A boolean indicating whether the configs passed should be kept in the
+            input dictionary after adding them to the record. When set to True, the
+            data is duplicated in memory. If memory is a concern, set it to False.
+        """
+        super().__init__(_check_key, _check_value)
+        if configs_dict:
+            for k in list(configs_dict.keys()):
+                self[k] = configs_dict[k]
+                if not keep_input:
+                    del configs_dict[k]
diff --git a/src/py/flwr/common/record/conversion_utils.py b/src/py/flwr/common/record/conversion_utils.py
new file mode 100644
index 000000000000..7cc0b04283e9
--- /dev/null
+++ b/src/py/flwr/common/record/conversion_utils.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
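To make the `keep_input` semantics of `ConfigsRecord` concrete, a small usage sketch (keys and values are illustrative):

from flwr.common.record import ConfigsRecord

configs = {"lr": 0.01, "optimizer": "sgd", "weights": [0.25, 0.75]}

record = ConfigsRecord(configs)  # keep_input=True (default): `configs` stays intact
assert "lr" in configs

record = ConfigsRecord(configs, keep_input=False)  # entries are moved, not copied
assert len(configs) == 0

record["batch_size"] = 32  # every assignment goes through _check_key/_check_value
# record["nested"] = {"a": 1}  # would raise TypeError: a dict is not a ConfigsScalar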
+# ============================================================================== +"""Conversion utility functions for Records.""" + + +from io import BytesIO + +import numpy as np + +from ..constant import SType +from ..typing import NDArray +from .parametersrecord import Array + + +def array_from_numpy(ndarray: NDArray) -> Array: + """Create Array from NumPy ndarray.""" + buffer = BytesIO() + # WARNING: NEVER set allow_pickle to true. + # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html + np.save(buffer, ndarray, allow_pickle=False) + data = buffer.getvalue() + return Array( + dtype=str(ndarray.dtype), + shape=list(ndarray.shape), + stype=SType.NUMPY, + data=data, + ) diff --git a/src/py/flwr/common/record/conversion_utils_test.py b/src/py/flwr/common/record/conversion_utils_test.py new file mode 100644 index 000000000000..84be37fda4a3 --- /dev/null +++ b/src/py/flwr/common/record/conversion_utils_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for conversion utility functions.""" + + +import unittest +from io import BytesIO + +import numpy as np + +from ..constant import SType +from .conversion_utils import array_from_numpy + + +class TestArrayFromNumpy(unittest.TestCase): + """Unit tests for array_from_numpy.""" + + def test_array_from_numpy(self) -> None: + """Test the array_from_numpy function.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + + # Execute + array_instance = array_from_numpy(original_array) + buffer = BytesIO(array_instance.data) + deserialized_array = np.load(buffer, allow_pickle=False) + + # Assert + self.assertEqual(array_instance.dtype, str(original_array.dtype)) + self.assertEqual(array_instance.shape, list(original_array.shape)) + self.assertEqual(array_instance.stype, SType.NUMPY) + np.testing.assert_array_equal(deserialized_array, original_array) diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py new file mode 100644 index 000000000000..81b02303421b --- /dev/null +++ b/src/py/flwr/common/record/metricsrecord.py @@ -0,0 +1,86 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
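A round-trip sketch matching the unit test above (shape and dtype are arbitrary):

from io import BytesIO

import numpy as np

from flwr.common.record import array_from_numpy

ndarray = np.arange(6, dtype=np.float32).reshape(2, 3)
array = array_from_numpy(ndarray)  # dtype='float32', shape=[2, 3], stype=SType.NUMPY

# Deserialize the same way the test does: read the np.save payload with np.load
restored = np.load(BytesIO(array.data), allow_pickle=False)
assert np.array_equal(restored, ndarray)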
+# ==============================================================================
+"""MetricsRecord."""
+
+
+from typing import Dict, Optional, get_args
+
+from flwr.common.typing import MetricsRecordValues, MetricsScalar
+
+from .typeddict import TypedDict
+
+
+def _check_key(key: str) -> None:
+    """Check if key is of expected type."""
+    if not isinstance(key, str):
+        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
+
+
+def _check_value(value: MetricsRecordValues) -> None:
+    def is_valid(__v: MetricsScalar) -> None:
+        """Check if value is of expected type."""
+        if not isinstance(__v, get_args(MetricsScalar)) or isinstance(__v, bool):
+            raise TypeError(
+                "Not all values are of valid type."
+                f" Expected `{MetricsRecordValues}` but `{type(__v)}` was passed."
+            )
+
+    if isinstance(value, list):
+        # If your lists are large (e.g. 1M+ elements) this will be slow
+        # 1s to check 10M element list on a M2 Pro
+        # In such settings, you'd be better off treating such a metric as
+        # an array and passing it to a ParametersRecord.
+        # Empty lists are valid
+        if len(value) > 0:
+            is_valid(value[0])
+            # all elements in the list must be of the same valid type
+            # this is needed for protobuf
+            value_type = type(value[0])
+            if not all(isinstance(v, value_type) for v in value):
+                raise TypeError(
+                    "All values in a list must be of the same valid type. "
+                    f"One of {MetricsScalar}."
+                )
+    else:
+        is_valid(value)
+
+
+class MetricsRecord(TypedDict[str, MetricsRecordValues]):
+    """Metrics record."""
+
+    def __init__(
+        self,
+        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
+        keep_input: bool = True,
+    ):
+        """Construct a MetricsRecord object.
+
+        Parameters
+        ----------
+        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
+            A dictionary that stores basic types (i.e. `int`, `float` as defined
+            in `MetricsScalar`) and lists of such types (see `MetricsScalarList`).
+        keep_input : bool (default: True)
+            A boolean indicating whether the metrics passed should be kept in the
+            input dictionary after adding them to the record. When set to True, the
+            data is duplicated in memory. If memory is a concern, set it to False.
+        """
+        super().__init__(_check_key, _check_value)
+        if metrics_dict:
+            for k in list(metrics_dict.keys()):
+                self[k] = metrics_dict[k]
+                if not keep_input:
+                    del metrics_dict[k]
diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py
similarity index 64%
rename from src/py/flwr/common/parametersrecord.py
rename to src/py/flwr/common/record/parametersrecord.py
index ef02a0789ddf..17bf3f608db7 100644
--- a/src/py/flwr/common/parametersrecord.py
+++ b/src/py/flwr/common/record/parametersrecord.py
@@ -14,9 +14,15 @@
 # ==============================================================================
 """ParametersRecord and Array."""
 
+from dataclasses import dataclass
+from io import BytesIO
+from typing import List, Optional, OrderedDict, cast
 
-from dataclasses import dataclass, field
-from typing import List, Optional, OrderedDict
+import numpy as np
+
+from ..constant import SType
+from ..typing import NDArray
+from .typeddict import TypedDict
 
 
 @dataclass
@@ -49,9 +55,35 @@ class Array:
     stype: str
     data: bytes
 
+    def numpy(self) -> NDArray:
+        """Return the array as a NumPy array."""
+        if self.stype != SType.NUMPY:
+            raise TypeError(
+                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
+            )
+        bytes_io = BytesIO(self.data)
+        # WARNING: NEVER set allow_pickle to true.
+ # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html + ndarray_deserialized = np.load(bytes_io, allow_pickle=False) + return cast(NDArray, ndarray_deserialized) + + +def _check_key(key: str) -> None: + """Check if key is of expected type.""" + if not isinstance(key, str): + raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.") + + +def _check_value(value: Array) -> None: + if not isinstance(value, Array): + raise TypeError( + f"Value must be of type `{Array}` but `{type(value)}` was passed." + ) + @dataclass -class ParametersRecord: +class ParametersRecord(TypedDict[str, Array]): """Parameters record. A dataclass storing named Arrays in order. This means that it holds entries as an @@ -59,8 +91,6 @@ class ParametersRecord: PyTorch's state_dict, but holding serialised tensors instead. """ - data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) - def __init__( self, array_dict: Optional[OrderedDict[str, Array]] = None, @@ -81,37 +111,9 @@ def __init__( parameters after adding it to the record, set this flag to True. When set to True, the data is duplicated in memory. """ - self.data = OrderedDict() + super().__init__(_check_key, _check_value) if array_dict: - self.set_parameters(array_dict, keep_input=keep_input) - - def set_parameters( - self, array_dict: OrderedDict[str, Array], keep_input: bool = False - ) -> None: - """Add parameters to record. - - Parameters - ---------- - array_dict : OrderedDict[str, Array] - A dictionary that stores serialized array-like or tensor-like objects. - keep_input : bool (default: False) - A boolean indicating whether parameters should be deleted from the input - dictionary immediately after adding them to the record. - """ - if any(not isinstance(k, str) for k in array_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}") - if any(not isinstance(v, Array) for v in array_dict.values()): - raise TypeError(f"Not all values are of valid type. Expected {Array}") - - if keep_input: - # Copy - self.data = OrderedDict(array_dict) - else: - # Add entries to dataclass without duplicating memory - for key in list(array_dict.keys()): - self.data[key] = array_dict[key] - del array_dict[key] - - def __getitem__(self, key: str) -> Array: - """Retrieve an element stored in record.""" - return self.data[key] + for k in list(array_dict.keys()): + self[k] = array_dict[k] + if not keep_input: + del array_dict[k] diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py new file mode 100644 index 000000000000..9633af7bda6d --- /dev/null +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
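Putting `Array` and `ParametersRecord` together, a minimal sketch (the two random weight tensors stand in for real model weights):

from collections import OrderedDict

import numpy as np

from flwr.common.record import ParametersRecord, array_from_numpy

weights = [np.random.randn(3, 3), np.random.randn(3)]
array_dict = OrderedDict((str(i), array_from_numpy(w)) for i, w in enumerate(weights))

# keep_input=False moves the Arrays into the record instead of copying them
p_record = ParametersRecord(array_dict, keep_input=False)
assert len(array_dict) == 0

# Entries keep their insertion order, much like a serialized state_dict
restored = [p_record[key].numpy() for key in p_record.keys()]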
+# ============================================================================== +"""Unit tests for ParametersRecord and Array.""" + + +import unittest +from io import BytesIO + +import numpy as np + +from ..constant import SType +from .parametersrecord import Array + + +class TestArray(unittest.TestCase): + """Unit tests for Array.""" + + def test_numpy_conversion_valid(self) -> None: + """Test the numpy method with valid Array instance.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + buffer = BytesIO() + np.save(buffer, original_array, allow_pickle=False) + buffer.seek(0) + + # Execute + array_instance = Array( + dtype=str(original_array.dtype), + shape=list(original_array.shape), + stype=SType.NUMPY, + data=buffer.read(), + ) + converted_array = array_instance.numpy() + + # Assert + np.testing.assert_array_equal(converted_array, original_array) + + def test_numpy_conversion_invalid(self) -> None: + """Test the numpy method with invalid Array instance.""" + # Prepare + array_instance = Array( + dtype="float32", + shape=[3], + stype="invalid_stype", # Non-numpy stype + data=b"", + ) + + # Execute and assert + with self.assertRaises(TypeError): + array_instance.numpy() diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py new file mode 100644 index 000000000000..d8ef44ab15c2 --- /dev/null +++ b/src/py/flwr/common/record/recordset.py @@ -0,0 +1,79 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""RecordSet.""" + + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Type, TypeVar + +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import ParametersRecord +from .typeddict import TypedDict + +T = TypeVar("T") + + +@dataclass +class RecordSet: + """RecordSet stores groups of parameters, metrics and configs.""" + + _parameters_records: TypedDict[str, ParametersRecord] + _metrics_records: TypedDict[str, MetricsRecord] + _configs_records: TypedDict[str, ConfigsRecord] + + def __init__( + self, + parameters_records: Optional[Dict[str, ParametersRecord]] = None, + metrics_records: Optional[Dict[str, MetricsRecord]] = None, + configs_records: Optional[Dict[str, ConfigsRecord]] = None, + ) -> None: + def _get_check_fn(__t: Type[T]) -> Callable[[T], None]: + def _check_fn(__v: T) -> None: + if not isinstance(__v, __t): + raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.") + + return _check_fn + + self._parameters_records = TypedDict[str, ParametersRecord]( + _get_check_fn(str), _get_check_fn(ParametersRecord) + ) + self._metrics_records = TypedDict[str, MetricsRecord]( + _get_check_fn(str), _get_check_fn(MetricsRecord) + ) + self._configs_records = TypedDict[str, ConfigsRecord]( + _get_check_fn(str), _get_check_fn(ConfigsRecord) + ) + if parameters_records is not None: + self._parameters_records.update(parameters_records) + if metrics_records is not None: + self._metrics_records.update(metrics_records) + if configs_records is not None: + self._configs_records.update(configs_records) + + @property + def parameters_records(self) -> TypedDict[str, ParametersRecord]: + """Dictionary holding ParametersRecord instances.""" + return self._parameters_records + + @property + def metrics_records(self) -> TypedDict[str, MetricsRecord]: + """Dictionary holding MetricsRecord instances.""" + return self._metrics_records + + @property + def configs_records(self) -> TypedDict[str, ConfigsRecord]: + """Dictionary holding ConfigsRecord instances.""" + return self._configs_records diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/record/recordset_test.py similarity index 84% rename from src/py/flwr/common/recordset_test.py rename to src/py/flwr/common/record/recordset_test.py index e1825eaeef14..bcf5c75a1e02 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -14,20 +14,18 @@ # ============================================================================== """RecordSet tests.""" +from copy import deepcopy from typing import Callable, Dict, List, OrderedDict, Type, Union import numpy as np import pytest -from .configsrecord import ConfigsRecord -from .metricsrecord import MetricsRecord -from .parameter import ndarrays_to_parameters, parameters_to_ndarrays -from .parametersrecord import Array, ParametersRecord -from .recordset_compat import ( +from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.recordset_compat import ( parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import ( +from flwr.common.typing import ( ConfigsRecordValues, MetricsRecordValues, NDArray, @@ -35,6 +33,8 @@ Parameters, ) +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord
+
 
 def get_ndarrays() -> NDArrays:
     """Return list of NumPy arrays."""
@@ -70,7 +70,7 @@ def test_parameters_to_array_and_back() -> None:
     """Test conversion between legacy Parameters and Array."""
     ndarrays = get_ndarrays()
 
-    # Array represents a single array, unlike Paramters, which represent a
+    # Array represents a single array, unlike `Parameters`, which represents a
     # list of arrays
     ndarray = ndarrays[0]
 
@@ -87,34 +87,50 @@ def test_parameters_to_array_and_back() -> None:
     assert np.array_equal(ndarray, ndarray_)
 
 
-def test_parameters_to_parametersrecord_and_back() -> None:
+@pytest.mark.parametrize(
+    "keep_input, validate_freed_fn",
+    [
+        (False, lambda x, x_copy, y: len(x.tensors) == 0),  # check tensors were freed
+        (True, lambda x, x_copy, y: x.tensors == y.tensors),  # check they are equal
+    ],
+)
+def test_parameters_to_parametersrecord_and_back(
+    keep_input: bool,
+    validate_freed_fn: Callable[[Parameters, Parameters, Parameters], bool],
+) -> None:
     """Test conversion between legacy Parameters and ParametersRecords."""
     ndarrays = get_ndarrays()
 
     parameters = ndarrays_to_parameters(ndarrays)
+    parameters_copy = deepcopy(parameters)
 
-    params_record = parameters_to_parametersrecord(parameters=parameters)
+    params_record = parameters_to_parametersrecord(
+        parameters=parameters, keep_input=keep_input
+    )
 
-    parameters_ = parametersrecord_to_parameters(params_record)
+    parameters_ = parametersrecord_to_parameters(params_record, keep_input=keep_input)
 
     ndarrays_ = parameters_to_ndarrays(parameters=parameters_)
 
+    # Validate returned NDArrays match those at the beginning
     for arr, arr_ in zip(ndarrays, ndarrays_):
-        assert np.array_equal(arr, arr_)
+        assert np.array_equal(arr, arr_), "NDArrays don't match after the round-trip"
+
+    # Validate initial Parameters object has been handled according to `keep_input`
+    assert validate_freed_fn(parameters, parameters_copy, parameters_)
 
 
 def test_set_parameters_while_keeping_intputs() -> None:
     """Tests keep_input functionality in ParametersRecord."""
     # Adding parameters to a record that doesn't erase entries in the input `array_dict`
-    p_record = ParametersRecord(keep_input=True)
     array_dict = OrderedDict(
         {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
     )
-    p_record.set_parameters(array_dict, keep_input=True)
+    p_record = ParametersRecord(array_dict, keep_input=True)
 
     # Creating a second parametersrecord passing the same `array_dict` (not erased)
     p_record_2 = ParametersRecord(array_dict)
-    assert p_record.data == p_record_2.data
+    assert p_record == p_record_2
 
     # Now it should be empty (the second ParametersRecord wasn't flagged to keep it)
     assert len(array_dict) == 0
@@ -126,7 +142,7 @@ def test_set_parameters_with_correct_types() -> None:
     array_dict = OrderedDict(
         {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
     )
-    p_record.set_parameters(array_dict)
+    p_record.update(array_dict)
 
 
 @pytest.mark.parametrize(
@@ -151,7 +167,7 @@ def test_set_parameters_with_incorrect_types(
     }
 
     with pytest.raises(TypeError):
-        p_record.set_parameters(array_dict)  # type: ignore
+        p_record.update(array_dict)
 
 
 @pytest.mark.parametrize(
@@ -179,10 +195,10 @@ def test_set_metrics_to_metricsrecord_with_correct_types(
     )
 
     # Add metric
-    m_record.set_metrics(my_metrics)
+    m_record.update(my_metrics)
 
     # Check metrics are actually added
-    assert my_metrics == m_record.data
+    assert my_metrics == m_record
 
 
 @pytest.mark.parametrize(
@@ -232,7 +248,7 @@ def test_set_metrics_to_metricsrecord_with_incorrect_types(
     )
with pytest.raises(TypeError):
-        m_record.set_metrics(my_metrics)  # type: ignore
+        m_record.update(my_metrics)
 
 
 @pytest.mark.parametrize(
@@ -246,8 +262,6 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
     keep_input: bool,
 ) -> None:
     """Test keep_input functionality for MetricsRecord."""
-    m_record = MetricsRecord(keep_input=keep_input)
-
     # constructing a valid input
     labels = [1, 2.0]
     arrays = get_ndarrays()
@@ -258,14 +272,14 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
     my_metrics_copy = my_metrics.copy()
 
     # Add metric
-    m_record.set_metrics(my_metrics, keep_input=keep_input)
+    m_record = MetricsRecord(my_metrics, keep_input=keep_input)
 
     # Check metrics are actually added
     # Check that the input dict has been emptied when such behaviour is enabled
     if keep_input:
-        assert my_metrics == m_record.data
+        assert my_metrics == m_record
     else:
-        assert my_metrics_copy == m_record.data
+        assert my_metrics_copy == m_record
         assert len(my_metrics) == 0
 
 
@@ -282,7 +296,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
         (str, lambda x: x.flatten().astype("float").tolist()),  # str: List[float]
         (str, lambda x: x.flatten().astype("bool").tolist()),  # str: List[bool]
         (str, lambda x: [x.flatten().tobytes()]),  # str: List[bytes]
-        (str, lambda x: []),  # str: empyt list
+        (str, lambda x: []),  # str: empty list
     ],
 )
 def test_set_configs_to_configsrecord_with_correct_types(
@@ -300,7 +314,7 @@ def test_set_configs_to_configsrecord_with_correct_types(
     c_record = ConfigsRecord(my_configs)
 
     # check values are actually there
-    assert c_record.data == my_configs
+    assert c_record == my_configs
 
 
 @pytest.mark.parametrize(
@@ -334,14 +348,14 @@ def test_set_configs_to_configsrecord_with_incorrect_types(
     value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
 ) -> None:
     """Test adding configs of various unsupported types to a ConfigsRecord."""
-    m_record = ConfigsRecord()
+    c_record = ConfigsRecord()
 
     labels = [1, 2.0]
     arrays = get_ndarrays()
 
-    my_metrics = OrderedDict(
+    my_configs = OrderedDict(
         {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
     )
 
     with pytest.raises(TypeError):
-        m_record.set_configs(my_metrics)  # type: ignore
+        c_record.update(my_configs)
diff --git a/src/py/flwr/common/record/typeddict.py b/src/py/flwr/common/record/typeddict.py
new file mode 100644
index 000000000000..23d70dc4f7e8
--- /dev/null
+++ b/src/py/flwr/common/record/typeddict.py
@@ -0,0 +1,113 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
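The construction pattern these tests exercise carries over directly to user code; a short `MetricsRecord` sketch (values are illustrative):

from flwr.common.record import MetricsRecord

metrics = {"loss": 0.123, "accuracy": [0.8, 0.85, 0.9]}

m_record = MetricsRecord(metrics, keep_input=False)  # input dict is emptied
assert len(metrics) == 0

# m_record["name"] = "round-1"  # would raise TypeError: str is not a MetricsScalar
# m_record["done"] = True       # would raise TypeError: bool is explicitly rejected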
+# ==============================================================================
+"""Typed dict base class for *Records."""
+
+
+from typing import Any, Callable, Dict, Generic, Iterator, Tuple, TypeVar, cast
+
+K = TypeVar("K")  # Key type
+V = TypeVar("V")  # Value type
+
+
+class TypedDict(Generic[K, V]):
+    """Typed dictionary."""
+
+    def __init__(
+        self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None]
+    ):
+        self._data: Dict[K, V] = {}
+        self._check_key_fn = check_key_fn
+        self._check_value_fn = check_value_fn
+
+    def __setitem__(self, key: K, value: V) -> None:
+        """Set the given key to the given value after type checking."""
+        # Check the types of key and value
+        self._check_key_fn(key)
+        self._check_value_fn(value)
+        # Set key-value pair
+        self._data[key] = value
+
+    def __delitem__(self, key: K) -> None:
+        """Remove the item with the specified key."""
+        del self._data[key]
+
+    def __getitem__(self, item: K) -> V:
+        """Return the value for the specified key."""
+        return self._data[item]
+
+    def __iter__(self) -> Iterator[K]:
+        """Yield an iterator over the keys of the dictionary."""
+        return iter(self._data)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the dictionary."""
+        return self._data.__repr__()
+
+    def __len__(self) -> int:
+        """Return the number of items in the dictionary."""
+        return len(self._data)
+
+    def __contains__(self, key: K) -> bool:
+        """Check if the dictionary contains the specified key."""
+        return key in self._data
+
+    def __eq__(self, other: object) -> bool:
+        """Compare this instance to another dictionary or TypedDict."""
+        if isinstance(other, TypedDict):
+            return self._data == other._data
+        if isinstance(other, dict):
+            return self._data == other
+        return NotImplemented
+
+    def items(self) -> Iterator[Tuple[K, V]]:
+        """R.items() -> a set-like object providing a view on R's items."""
+        return cast(Iterator[Tuple[K, V]], self._data.items())
+
+    def keys(self) -> Iterator[K]:
+        """R.keys() -> a set-like object providing a view on R's keys."""
+        return cast(Iterator[K], self._data.keys())
+
+    def values(self) -> Iterator[V]:
+        """R.values() -> an object providing a view on R's values."""
+        return cast(Iterator[V], self._data.values())
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """R.update([E, ]**F) -> None.
+
+        Update R from dict/iterable E and F.
+        """
+        for key, value in dict(*args, **kwargs).items():
+            self[key] = value
+
+    def pop(self, key: K) -> V:
+        """R.pop(k) -> v, remove specified key and return the corresponding value.
+
+        If k is not found, KeyError is raised.
+        """
+        return self._data.pop(key)
+
+    def get(self, key: K, default: V) -> V:
+        """R.get(k, d) -> R[k] if k in R, else d."""
+        return self._data.get(key, default)
+
+    def clear(self) -> None:
+        """R.clear() -> None.
+
+        Remove all items from R.
+        """
+        self._data.clear()
diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py
deleted file mode 100644
index 61c880c970b8..000000000000
--- a/src/py/flwr/common/recordset.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
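Since all three record types now share this `TypedDict` base, a sketch of the generic behaviour in isolation (the two checker functions are illustrative):

from flwr.common.record.typeddict import TypedDict


def check_key(key: str) -> None:
    if not isinstance(key, str):
        raise TypeError("keys must be str")


def check_value(value: int) -> None:
    if not isinstance(value, int):
        raise TypeError("values must be int")


d = TypedDict[str, int](check_key, check_value)
d["rounds"] = 3  # passes both checks
d.update({"clients": 10})  # update() funnels through __setitem__, so it is checked too
assert d == {"rounds": 3, "clients": 10}  # __eq__ also accepts plain dicts
# d["bad"] = "oops"  # would raise TypeError from check_value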
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""RecordSet."""
-
-
-from dataclasses import dataclass, field
-from typing import Dict
-
-from .configsrecord import ConfigsRecord
-from .metricsrecord import MetricsRecord
-from .parametersrecord import ParametersRecord
-
-
-@dataclass
-class RecordSet:
-    """Definition of RecordSet."""
-
-    parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
-    metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
-    configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
-
-    def set_parameters(self, name: str, record: ParametersRecord) -> None:
-        """Add a ParametersRecord."""
-        self.parameters[name] = record
-
-    def get_parameters(self, name: str) -> ParametersRecord:
-        """Get a ParametesRecord."""
-        return self.parameters[name]
-
-    def del_parameters(self, name: str) -> None:
-        """Delete a ParametersRecord."""
-        del self.parameters[name]
-
-    def set_metrics(self, name: str, record: MetricsRecord) -> None:
-        """Add a MetricsRecord."""
-        self.metrics[name] = record
-
-    def get_metrics(self, name: str) -> MetricsRecord:
-        """Get a MetricsRecord."""
-        return self.metrics[name]
-
-    def del_metrics(self, name: str) -> None:
-        """Delete a MetricsRecord."""
-        del self.metrics[name]
-
-    def set_configs(self, name: str, record: ConfigsRecord) -> None:
-        """Add a ConfigsRecord."""
-        self.configs[name] = record
-
-    def get_configs(self, name: str) -> ConfigsRecord:
-        """Get a ConfigsRecord."""
-        return self.configs[name]
-
-    def del_configs(self, name: str) -> None:
-        """Delete a ConfigsRecord."""
-        del self.configs[name]
diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py
index c45f7fcd9fb8..394ea1353bab 100644
--- a/src/py/flwr/common/recordset_compat.py
+++ b/src/py/flwr/common/recordset_compat.py
@@ -17,10 +17,7 @@
 
 from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
 
-from .configsrecord import ConfigsRecord
-from .metricsrecord import MetricsRecord
-from .parametersrecord import Array, ParametersRecord
-from .recordset import RecordSet
+from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet
 from .typing import (
     Code,
     ConfigsRecordValues,
@@ -40,26 +37,28 @@
 
 
 def parametersrecord_to_parameters(
-    record: ParametersRecord, keep_input: bool = False
+    record: ParametersRecord, keep_input: bool
 ) -> Parameters:
     """Convert ParametersRecord to legacy Parameters.
 
-    Warning: Because `Arrays` in `ParametersRecord` encode more information of the
+    Warnings
+    --------
+    Because `Arrays` in `ParametersRecord` encode more information of the
     array-like or tensor-like data (e.g. their datatype, shape) than `Parameters` it
     might not be possible to reconstruct such data structures from `Parameters` objects
-    alone. Additional information or metadta must be provided from elsewhere.
+    alone. Additional information or metadata must be provided from elsewhere.
 
     Parameters
     ----------
     record : ParametersRecord
         The record to be converted into Parameters.
- keep_input : bool (default: False) + keep_input : bool A boolean indicating whether entries in the record should be deleted from the input dictionary immediately after adding them to the record. """ parameters = Parameters(tensors=[], tensor_type="") - for key in list(record.data.keys()): + for key in list(record.keys()): parameters.tensors.append(record[key].data) if not parameters.tensor_type: @@ -68,13 +67,13 @@ def parametersrecord_to_parameters( parameters.tensor_type = record[key].stype if not keep_input: - del record.data[key] + del record[key] return parameters def parameters_to_parametersrecord( - parameters: Parameters, keep_input: bool = False + parameters: Parameters, keep_input: bool ) -> ParametersRecord: """Convert legacy Parameters into a single ParametersRecord. @@ -86,28 +85,25 @@ def parameters_to_parametersrecord( ---------- parameters : Parameters Parameters object to be represented as a ParametersRecord. - keep_input : bool (default: False) + keep_input : bool A boolean indicating whether parameters should be deleted from the input Parameters object (i.e. a list of serialized NumPy arrays) immediately after adding them to the record. """ tensor_type = parameters.tensor_type - p_record = ParametersRecord() - num_arrays = len(parameters.tensors) + ordered_dict = OrderedDict() for idx in range(num_arrays): if keep_input: tensor = parameters.tensors[idx] else: tensor = parameters.tensors.pop(0) - p_record.set_parameters( - OrderedDict( - {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} - ) + ordered_dict[str(idx)] = Array( + data=tensor, dtype="", stype=tensor_type, shape=[] ) - return p_record + return ParametersRecord(ordered_dict, keep_input=keep_input) def _check_mapping_from_recordscalartype_to_scalar( @@ -133,16 +129,16 @@ def _recordset_to_fit_or_evaluate_ins_components( ) -> Tuple[Parameters, Dict[str, Scalar]]: """Derive Fit/Evaluate Ins from a RecordSet.""" # get Array and construct Parameters - parameters_record = recordset.get_parameters(f"{ins_str}.parameters") + parameters_record = recordset.parameters_records[f"{ins_str}.parameters"] parameters = parametersrecord_to_parameters( parameters_record, keep_input=keep_input ) # get config dict - config_record = recordset.get_configs(f"{ins_str}.config") - - config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records[f"{ins_str}.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return parameters, config_dict @@ -153,13 +149,11 @@ def _fit_or_evaluate_ins_to_recordset( recordset = RecordSet() ins_str = "fitins" if isinstance(ins, FitIns) else "evaluateins" - recordset.set_parameters( - name=f"{ins_str}.parameters", - record=parameters_to_parametersrecord(ins.parameters, keep_input=keep_input), - ) + parametersrecord = parameters_to_parametersrecord(ins.parameters, keep_input) + recordset.parameters_records[f"{ins_str}.parameters"] = parametersrecord - recordset.set_configs( - name=f"{ins_str}.config", record=ConfigsRecord(ins.config) # type: ignore + recordset.configs_records[f"{ins_str}.config"] = ConfigsRecord( + ins.config # type: ignore ) return recordset @@ -174,12 +168,12 @@ def _embed_status_into_recordset( } # we add it to a `ConfigsRecord`` because the `status.message`` is a string # and `str` values aren't supported in `MetricsRecords` - recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(status_dict)) + 
recordset.configs_records[f"{res_str}.status"] = ConfigsRecord(status_dict) return recordset def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status: - status = recordset.get_configs(f"{res_str}.status") + status = recordset.configs_records[f"{res_str}.status"] code = cast(int, status["code"]) return Status(code=Code(code), message=str(status["message"])) @@ -204,15 +198,15 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes: """Derive FitRes from a RecordSet object.""" ins_str = "fitres" parameters = parametersrecord_to_parameters( - recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input + recordset.parameters_records[f"{ins_str}.parameters"], keep_input=keep_input ) num_examples = cast( - int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] ) - configs_record = recordset.get_configs(f"{ins_str}.metrics") - - metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + configs_record = recordset.configs_records[f"{ins_str}.metrics"] + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) status = _extract_status_from_recordset(ins_str, recordset) return FitRes( @@ -226,16 +220,17 @@ def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet: res_str = "fitres" - recordset.set_configs( - name=f"{res_str}.metrics", record=ConfigsRecord(fitres.metrics) # type: ignore + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + fitres.metrics # type: ignore ) - recordset.set_metrics( - name=f"{res_str}.num_examples", - record=MetricsRecord({"num_examples": fitres.num_examples}), + recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": fitres.num_examples}, ) - recordset.set_parameters( - name=f"{res_str}.parameters", - record=parameters_to_parametersrecord(fitres.parameters, keep_input), + recordset.parameters_records[f"{res_str}.parameters"] = ( + parameters_to_parametersrecord( + fitres.parameters, + keep_input, + ) ) # status @@ -264,14 +259,15 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes: """Derive EvaluateRes from a RecordSet object.""" ins_str = "evaluateres" - loss = cast(int, recordset.get_metrics(f"{ins_str}.loss")["loss"]) + loss = cast(int, recordset.metrics_records[f"{ins_str}.loss"]["loss"]) num_examples = cast( - int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] ) - configs_record = recordset.get_configs(f"{ins_str}.metrics") + configs_record = recordset.configs_records[f"{ins_str}.metrics"] - metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) status = _extract_status_from_recordset(ins_str, recordset) return EvaluateRes( @@ -285,21 +281,18 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: res_str = "evaluateres" # loss - recordset.set_metrics( - name=f"{res_str}.loss", - record=MetricsRecord({"loss": evaluateres.loss}), + recordset.metrics_records[f"{res_str}.loss"] = MetricsRecord( + {"loss": evaluateres.loss}, ) # num_examples - recordset.set_metrics( - name=f"{res_str}.num_examples", - record=MetricsRecord({"num_examples": evaluateres.num_examples}), + 
recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": evaluateres.num_examples}, ) # metrics - recordset.set_configs( - name=f"{res_str}.metrics", - record=ConfigsRecord(evaluateres.metrics), # type: ignore + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + evaluateres.metrics, # type: ignore ) # status @@ -312,9 +305,9 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns: """Derive GetParametersIns from a RecordSet object.""" - config_record = recordset.get_configs("getparametersins.config") - - config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records["getparametersins.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return GetParametersIns(config=config_dict) @@ -323,19 +316,22 @@ def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> Record """Construct a RecordSet from a GetParametersIns object.""" recordset = RecordSet() - recordset.set_configs( - name="getparametersins.config", - record=ConfigsRecord(getparameters_ins.config), # type: ignore + recordset.configs_records["getparametersins.config"] = ConfigsRecord( + getparameters_ins.config, # type: ignore ) return recordset -def getparametersres_to_recordset(getparametersres: GetParametersRes) -> RecordSet: +def getparametersres_to_recordset( + getparametersres: GetParametersRes, keep_input: bool +) -> RecordSet: """Construct a RecordSet from a GetParametersRes object.""" recordset = RecordSet() res_str = "getparametersres" - parameters_record = parameters_to_parametersrecord(getparametersres.parameters) - recordset.set_parameters(f"{res_str}.parameters", parameters_record) + parameters_record = parameters_to_parametersrecord( + getparametersres.parameters, keep_input=keep_input + ) + recordset.parameters_records[f"{res_str}.parameters"] = parameters_record # status recordset = _embed_status_into_recordset( @@ -345,11 +341,13 @@ def getparametersres_to_recordset(getparametersres: GetParametersRes) -> RecordS return recordset -def recordset_to_getparametersres(recordset: RecordSet) -> GetParametersRes: +def recordset_to_getparametersres( + recordset: RecordSet, keep_input: bool +) -> GetParametersRes: """Derive GetParametersRes from a RecordSet object.""" res_str = "getparametersres" parameters = parametersrecord_to_parameters( - recordset.get_parameters(f"{res_str}.parameters") + recordset.parameters_records[f"{res_str}.parameters"], keep_input=keep_input ) status = _extract_status_from_recordset(res_str, recordset) @@ -358,8 +356,9 @@ def recordset_to_getparametersres(recordset: RecordSet) -> GetParametersRes: def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: """Derive GetPropertiesIns from a RecordSet object.""" - config_record = recordset.get_configs("getpropertiesins.config") - config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records["getpropertiesins.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return GetPropertiesIns(config=config_dict) @@ -367,9 +366,8 @@ def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet: """Construct a 
RecordSet from a GetPropertiesRes object.""" recordset = RecordSet() - recordset.set_configs( - name="getpropertiesins.config", - record=ConfigsRecord(getpropertiesins.config), # type: ignore + recordset.configs_records["getpropertiesins.config"] = ConfigsRecord( + getpropertiesins.config, # type: ignore ) return recordset @@ -377,8 +375,9 @@ def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordS def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes: """Derive GetPropertiesRes from a RecordSet object.""" res_str = "getpropertiesres" - config_record = recordset.get_configs(f"{res_str}.properties") - properties = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records[f"{res_str}.properties"] + # pylint: disable-next=protected-access + properties = _check_mapping_from_recordscalartype_to_scalar(config_record._data) status = _extract_status_from_recordset(res_str, recordset=recordset) @@ -389,9 +388,8 @@ def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordS """Construct a RecordSet from a GetPropertiesRes object.""" recordset = RecordSet() res_str = "getpropertiesres" - recordset.set_configs( - name=f"{res_str}.properties", - record=ConfigsRecord(getpropertiesres.properties), # type: ignore + recordset.configs_records[f"{res_str}.properties"] = ConfigsRecord( + getpropertiesres.properties, # type: ignore ) # status recordset = _embed_status_into_recordset( diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py index ad91cd3a42fc..288326dc9e83 100644 --- a/src/py/flwr/common/recordset_compat_test.py +++ b/src/py/flwr/common/recordset_compat_test.py @@ -15,9 +15,10 @@ """RecordSet from legacy messages tests.""" from copy import deepcopy -from typing import Dict +from typing import Callable, Dict import numpy as np +import pytest from .parameter import ndarrays_to_parameters from .recordset_compat import ( @@ -136,42 +137,88 @@ def _get_valid_getpropertiesres() -> GetPropertiesRes: ) -def test_fitins_to_recordset_and_back() -> None: +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_fitins_to_recordset_and_back( + keep_input: bool, validate_freed_fn: Callable[[FitIns, FitIns, FitIns], bool] +) -> None: """Test conversion FitIns --> RecordSet --> FitIns.""" fitins = _get_valid_fitins() fitins_copy = deepcopy(fitins) - recordset = fitins_to_recordset(fitins, keep_input=False) + recordset = fitins_to_recordset(fitins, keep_input=keep_input) - fitins_ = recordset_to_fitins(recordset, keep_input=False) + fitins_ = recordset_to_fitins(recordset, keep_input=keep_input) - assert fitins_copy == fitins_ + assert validate_freed_fn(fitins, fitins_copy, fitins_) -def test_fitres_to_recordset_and_back() -> None: +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_fitres_to_recordset_and_back( + keep_input: bool, validate_freed_fn: Callable[[FitRes, FitRes, FitRes], bool] +) -> None: """Test conversion FitRes --> RecordSet --> FitRes.""" fitres = _get_valid_fitres() fitres_copy = deepcopy(fitres) - recordset = fitres_to_recordset(fitres, 
keep_input=False) - fitres_ = recordset_to_fitres(recordset, keep_input=False) + recordset = fitres_to_recordset(fitres, keep_input=keep_input) + fitres_ = recordset_to_fitres(recordset, keep_input=keep_input) - assert fitres_copy == fitres_ + assert validate_freed_fn(fitres, fitres_copy, fitres_) -def test_evaluateins_to_recordset_and_back() -> None: +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_evaluateins_to_recordset_and_back( + keep_input: bool, + validate_freed_fn: Callable[[EvaluateIns, EvaluateIns, EvaluateIns], bool], +) -> None: """Test conversion EvaluateIns --> RecordSet --> EvaluateIns.""" evaluateins = _get_valid_evaluateins() evaluateins_copy = deepcopy(evaluateins) - recordset = evaluateins_to_recordset(evaluateins, keep_input=False) + recordset = evaluateins_to_recordset(evaluateins, keep_input=keep_input) - evaluateins_ = recordset_to_evaluateins(recordset, keep_input=False) + evaluateins_ = recordset_to_evaluateins(recordset, keep_input=keep_input) - assert evaluateins_copy == evaluateins_ + assert validate_freed_fn(evaluateins, evaluateins_copy, evaluateins_) def test_evaluateres_to_recordset_and_back() -> None: @@ -222,13 +269,35 @@ def test_get_parameters_ins_to_recordset_and_back() -> None: assert getparameters_ins_copy == getparameters_ins_ -def test_get_parameters_res_to_recordset_and_back() -> None: +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_get_parameters_res_to_recordset_and_back( + keep_input: bool, + validate_freed_fn: Callable[ + [GetParametersRes, GetParametersRes, GetParametersRes], bool + ], +) -> None: """Test conversion GetParametersRes --> RecordSet --> GetParametersRes.""" getparameteres_res = _get_valid_getparametersres() getparameters_res_copy = deepcopy(getparameteres_res) - recordset = getparametersres_to_recordset(getparameteres_res) - getparameteres_res_ = recordset_to_getparametersres(recordset) + recordset = getparametersres_to_recordset(getparameteres_res, keep_input=keep_input) + getparameteres_res_ = recordset_to_getparametersres( + recordset, keep_input=keep_input + ) - assert getparameters_res_copy == getparameteres_res_ + assert validate_freed_fn( + getparameteres_res, getparameters_res_copy, getparameteres_res_ + ) diff --git a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py index 57afa56b7a08..e926a9531bea 100644 --- a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py +++ b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py @@ -15,7 +15,7 @@ """Utility functions for performing operations on Numpy NDArrays.""" -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union import numpy as np from numpy.typing import DTypeLike, NDArray @@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray def parameters_multiply( - parameters: List[NDArray[Any]], multiplier: int + parameters: List[NDArray[Any]], multiplier: Union[int, float] ) -> List[NDArray[Any]]: - """Multiply parameters by an integer multiplier.""" + """Multiply parameters by an integer/float multiplier.""" return 
[parameters[idx] * multiplier for idx in range(len(parameters))] def parameters_divide( - parameters: List[NDArray[Any]], divisor: int + parameters: List[NDArray[Any]], divisor: Union[int, float] ) -> List[NDArray[Any]]: - """Divide weight by an integer divisor.""" + """Divide weight by an integer/float divisor.""" return [parameters[idx] / divisor for idx in range(len(parameters))] diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 8dd21a6016f1..8a15908c13c5 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -15,30 +15,55 @@ """Constants for the SecAgg/SecAgg+ protocol.""" -# Names of stages -STAGE_SETUP = "setup" -STAGE_SHARE_KEYS = "share_keys" -STAGE_COLLECT_MASKED_INPUT = "collect_masked_input" -STAGE_UNMASK = "unmask" -STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK) - -# All valid keys in received/replied `named_values` dictionaries -KEY_STAGE = "stage" -KEY_SAMPLE_NUMBER = "sample_num" -KEY_SECURE_ID = "secure_id" -KEY_SHARE_NUMBER = "share_num" -KEY_THRESHOLD = "threshold" -KEY_CLIPPING_RANGE = "clipping_range" -KEY_TARGET_RANGE = "target_range" -KEY_MOD_RANGE = "mod_range" -KEY_PUBLIC_KEY_1 = "pk1" -KEY_PUBLIC_KEY_2 = "pk2" -KEY_DESTINATION_LIST = "dsts" -KEY_CIPHERTEXT_LIST = "ctxts" -KEY_SOURCE_LIST = "srcs" -KEY_PARAMETERS = "params" -KEY_MASKED_PARAMETERS = "masked_params" -KEY_ACTIVE_SECURE_ID_LIST = "active_sids" -KEY_DEAD_SECURE_ID_LIST = "dead_sids" -KEY_SECURE_ID_LIST = "sids" -KEY_SHARE_LIST = "shares" +from __future__ import annotations + +RECORD_KEY_STATE = "secaggplus_state" +RECORD_KEY_CONFIGS = "secaggplus_configs" +RATIO_QUANTIZATION_RANGE = 1073741824 # 1 << 30 + + +class Stage: + """Stages for the SecAgg+ protocol.""" + + SETUP = "setup" + SHARE_KEYS = "share_keys" + COLLECT_MASKED_VECTORS = "collect_masked_vectors" + UNMASK = "unmask" + _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK) + + @classmethod + def all(cls) -> tuple[str, str, str, str]: + """Return all stages.""" + return cls._stages + + def __new__(cls) -> Stage: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Key: + """Keys for the configs in the ConfigsRecord.""" + + STAGE = "stage" + SAMPLE_NUMBER = "sample_num" + SHARE_NUMBER = "share_num" + THRESHOLD = "threshold" + CLIPPING_RANGE = "clipping_range" + TARGET_RANGE = "target_range" + MOD_RANGE = "mod_range" + MAX_WEIGHT = "max_weight" + PUBLIC_KEY_1 = "pk1" + PUBLIC_KEY_2 = "pk2" + DESTINATION_LIST = "dsts" + CIPHERTEXT_LIST = "ctxts" + SOURCE_LIST = "srcs" + PARAMETERS = "params" + MASKED_PARAMETERS = "masked_params" + ACTIVE_NODE_ID_LIST = "active_nids" + DEAD_NODE_ID_LIST = "dead_nids" + NODE_ID_LIST = "nids" + SHARE_LIST = "shares" + + def __new__(cls) -> Key: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index def677e9d5d9..c373573477b9 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -23,16 +23,16 @@ def share_keys_plaintext_concat( - source: int, destination: int, b_share: bytes, sk_share: bytes + src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes ) -> bytes: """Combine 
arguments to bytes. Parameters ---------- - source : int - the secure ID of the source. - destination : int - the secure ID of the destination. + src_node_id : int + the node ID of the source. + dst_node_id : int + the node ID of the destination. b_share : bytes the private key share of the source sent to the destination. sk_share : bytes @@ -45,8 +45,8 @@ def share_keys_plaintext_concat( """ return b"".join( [ - int.to_bytes(source, 4, "little"), - int.to_bytes(destination, 4, "little"), + int.to_bytes(src_node_id, 8, "little", signed=True), + int.to_bytes(dst_node_id, 8, "little", signed=True), int.to_bytes(len(b_share), 4, "little"), b_share, sk_share, @@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by Returns ------- - source : int - the secure ID of the source. - destination : int - the secure ID of the destination. + src_node_id : int + the node ID of the source. + dst_node_id : int + the node ID of the destination. b_share : bytes the private key share of the source sent to the destination. sk_share : bytes the secret key share of the source sent to the destination. """ src, dst, mark = ( - int.from_bytes(plaintext[:4], "little"), - int.from_bytes(plaintext[4:8], "little"), - int.from_bytes(plaintext[8:12], "little"), + int.from_bytes(plaintext[:8], "little", signed=True), + int.from_bytes(plaintext[8:16], "little", signed=True), + int.from_bytes(plaintext[16:20], "little"), ) - ret = (src, dst, plaintext[12 : 12 + mark], plaintext[12 + mark :]) + ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :]) return ret diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 2600d46edddc..6c7a077d2f9f 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -17,9 +17,11 @@ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast -from google.protobuf.message import Message +from google.protobuf.message import Message as GrpcMessage # pylint: disable=E0611 +from flwr.proto.error_pb2 import Error as ProtoError +from flwr.proto.node_pb2 import Node from flwr.proto.recordset_pb2 import Array as ProtoArray from flwr.proto.recordset_pb2 import BoolList, BytesList from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord @@ -30,7 +32,7 @@ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet from flwr.proto.recordset_pb2 import Sint64List, StringList -from flwr.proto.task_pb2 import Value +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ( ClientMessage, Code, @@ -42,147 +44,9 @@ ) # pylint: enable=E0611 -from . 
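Taken together, the secure-aggregation changes above allow float scalars in the arithmetic helpers and widen node IDs to 8-byte signed integers. A minimal illustrative sketch (not part of the patch; module paths as they appear in this diff):

```python
import numpy as np

from flwr.common.secure_aggregation.ndarrays_arithmetic import parameters_multiply
from flwr.common.secure_aggregation.secaggplus_utils import (
    share_keys_plaintext_concat,
    share_keys_plaintext_separate,
)

# Float multipliers are now accepted alongside ints.
scaled = parameters_multiply([np.ones((2, 2)), np.arange(3.0)], 0.5)

# Node IDs are serialized as 8-byte signed little-endian integers,
# so negative 64-bit IDs survive the round trip.
plaintext = share_keys_plaintext_concat(
    src_node_id=-42, dst_node_id=7, b_share=b"\x01\x02", sk_share=b"\x03"
)
assert share_keys_plaintext_separate(plaintext) == (-42, 7, b"\x01\x02", b"\x03")
```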
import typing -from .configsrecord import ConfigsRecord -from .metricsrecord import MetricsRecord -from .parametersrecord import Array, ParametersRecord -from .recordset import RecordSet - -# === ServerMessage message === - - -def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessage: - """Serialize `ServerMessage` to ProtoBuf.""" - if server_message.get_properties_ins is not None: - return ServerMessage( - get_properties_ins=get_properties_ins_to_proto( - server_message.get_properties_ins, - ) - ) - if server_message.get_parameters_ins is not None: - return ServerMessage( - get_parameters_ins=get_parameters_ins_to_proto( - server_message.get_parameters_ins, - ) - ) - if server_message.fit_ins is not None: - return ServerMessage( - fit_ins=fit_ins_to_proto( - server_message.fit_ins, - ) - ) - if server_message.evaluate_ins is not None: - return ServerMessage( - evaluate_ins=evaluate_ins_to_proto( - server_message.evaluate_ins, - ) - ) - raise ValueError( - "No instruction set in ServerMessage, cannot serialize to ProtoBuf" - ) - - -def server_message_from_proto( - server_message_proto: ServerMessage, -) -> typing.ServerMessage: - """Deserialize `ServerMessage` from ProtoBuf.""" - field = server_message_proto.WhichOneof("msg") - if field == "get_properties_ins": - return typing.ServerMessage( - get_properties_ins=get_properties_ins_from_proto( - server_message_proto.get_properties_ins, - ) - ) - if field == "get_parameters_ins": - return typing.ServerMessage( - get_parameters_ins=get_parameters_ins_from_proto( - server_message_proto.get_parameters_ins, - ) - ) - if field == "fit_ins": - return typing.ServerMessage( - fit_ins=fit_ins_from_proto( - server_message_proto.fit_ins, - ) - ) - if field == "evaluate_ins": - return typing.ServerMessage( - evaluate_ins=evaluate_ins_from_proto( - server_message_proto.evaluate_ins, - ) - ) - raise ValueError( - "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" - ) - - -# === ClientMessage message === - - -def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessage: - """Serialize `ClientMessage` to ProtoBuf.""" - if client_message.get_properties_res is not None: - return ClientMessage( - get_properties_res=get_properties_res_to_proto( - client_message.get_properties_res, - ) - ) - if client_message.get_parameters_res is not None: - return ClientMessage( - get_parameters_res=get_parameters_res_to_proto( - client_message.get_parameters_res, - ) - ) - if client_message.fit_res is not None: - return ClientMessage( - fit_res=fit_res_to_proto( - client_message.fit_res, - ) - ) - if client_message.evaluate_res is not None: - return ClientMessage( - evaluate_res=evaluate_res_to_proto( - client_message.evaluate_res, - ) - ) - raise ValueError( - "No instruction set in ClientMessage, cannot serialize to ProtoBuf" - ) - - -def client_message_from_proto( - client_message_proto: ClientMessage, -) -> typing.ClientMessage: - """Deserialize `ClientMessage` from ProtoBuf.""" - field = client_message_proto.WhichOneof("msg") - if field == "get_properties_res": - return typing.ClientMessage( - get_properties_res=get_properties_res_from_proto( - client_message_proto.get_properties_res, - ) - ) - if field == "get_parameters_res": - return typing.ClientMessage( - get_parameters_res=get_parameters_res_from_proto( - client_message_proto.get_parameters_res, - ) - ) - if field == "fit_res": - return typing.ClientMessage( - fit_res=fit_res_from_proto( - client_message_proto.fit_res, - ) - ) - if field == 
"evaluate_res": - return typing.ClientMessage( - evaluate_res=evaluate_res_from_proto( - client_message_proto.evaluate_res, - ) - ) - raise ValueError( - "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" - ) - +from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing +from .message import Error, Message, Metadata +from .record.typeddict import TypedDict # === Parameters message === @@ -208,26 +72,9 @@ def reconnect_ins_to_proto(ins: typing.ReconnectIns) -> ServerMessage.ReconnectI return ServerMessage.ReconnectIns() -def reconnect_ins_from_proto(msg: ServerMessage.ReconnectIns) -> typing.ReconnectIns: - """Deserialize `ReconnectIns` from ProtoBuf.""" - return typing.ReconnectIns(seconds=msg.seconds) - - # === DisconnectRes message === -def disconnect_res_to_proto(res: typing.DisconnectRes) -> ClientMessage.DisconnectRes: - """Serialize `DisconnectRes` to ProtoBuf.""" - reason_proto = Reason.UNKNOWN - if res.reason == "RECONNECT": - reason_proto = Reason.RECONNECT - elif res.reason == "POWER_DISCONNECTED": - reason_proto = Reason.POWER_DISCONNECTED - elif res.reason == "WIFI_UNAVAILABLE": - reason_proto = Reason.WIFI_UNAVAILABLE - return ClientMessage.DisconnectRes(reason=reason_proto) - - def disconnect_res_from_proto(msg: ClientMessage.DisconnectRes) -> typing.DisconnectRes: """Deserialize `DisconnectRes` from ProtoBuf.""" if msg.reason == Reason.RECONNECT: @@ -508,7 +355,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: return cast(typing.Scalar, scalar) -# === Value messages === +# === Record messages === _type_to_field = { @@ -518,8 +365,6 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: str: "string", bytes: "bytes", } - - _list_type_to_class_and_field = { float: (DoubleList, "double_list"), int: (Sint64List, "sint64_list"), @@ -527,85 +372,21 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: str: (StringList, "string_list"), bytes: (BytesList, "bytes_list"), } - - -def _check_value(value: typing.Value) -> None: - if isinstance(value, tuple(_type_to_field.keys())): - return - if isinstance(value, list): - if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())): - data_type = type(value[0]) - for element in value: - if isinstance(element, data_type): - continue - raise TypeError( - f"Inconsistent type: the types of elements in the list must " - f"be the same (expected {data_type}, but got {type(element)})." - ) - else: - raise TypeError( - f"Accepted types: {bool, bytes, float, int, str} or " - f"list of these types." 
-        )
-
-
-def value_to_proto(value: typing.Value) -> Value:
-    """Serialize `Value` to ProtoBuf."""
-    _check_value(value)
-
-    arg = {}
-    if isinstance(value, list):
-        msg_class, field_name = _list_type_to_class_and_field[
-            type(value[0]) if len(value) > 0 else int
-        ]
-        arg[field_name] = msg_class(vals=value)
-    else:
-        arg[_type_to_field[type(value)]] = value
-    return Value(**arg)
-
-
-def value_from_proto(value_msg: Value) -> typing.Value:
-    """Deserialize `Value` from ProtoBuf."""
-    value_field = cast(str, value_msg.WhichOneof("value"))
-    if value_field.endswith("list"):
-        value = list(getattr(value_msg, value_field).vals)
-    else:
-        value = getattr(value_msg, value_field)
-    return cast(typing.Value, value)
-
-
-# === Named Values ===
-
-
-def named_values_to_proto(
-    named_values: Dict[str, typing.Value],
-) -> Dict[str, Value]:
-    """Serialize named values to ProtoBuf."""
-    return {name: value_to_proto(value) for name, value in named_values.items()}
-
-
-def named_values_from_proto(
-    named_values_proto: MutableMapping[str, Value]
-) -> Dict[str, typing.Value]:
-    """Deserialize named values from ProtoBuf."""
-    return {name: value_from_proto(value) for name, value in named_values_proto.items()}
-
-
-# === Record messages ===
-
-
 T = TypeVar("T")


 def _record_value_to_proto(
     value: Any, allowed_types: List[type], proto_class: Type[T]
 ) -> T:
-    """Serialize `*RecordValue` to ProtoBuf."""
+    """Serialize `*RecordValue` to ProtoBuf.
+
+    Note: `bool` MUST be put in the front of allowed_types if it exists.
+    """
     arg = {}
     for t in allowed_types:
         # Single element
         # Note: `isinstance(False, int) == True`.
-        if type(value) == t:  # pylint: disable=C0123
+        if isinstance(value, t):
             arg[_type_to_field[t]] = value
             return proto_class(**arg)
     # List
@@ -620,7 +401,7 @@ def _record_value_to_proto(
     )


-def _record_value_from_proto(value_proto: Message) -> Any:
+def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
     """Deserialize `*RecordValue` from ProtoBuf."""
     value_field = cast(str, value_proto.WhichOneof("value"))
     if value_field.endswith("list"):
@@ -631,9 +412,18 @@ def _record_value_from_proto(value_proto: Message) -> Any:

 def _record_value_dict_to_proto(
-    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
+    value_dict: TypedDict[str, Any],
+    allowed_types: List[type],
+    value_proto_class: Type[T],
 ) -> Dict[str, T]:
-    """Serialize the record value dict to ProtoBuf."""
+    """Serialize the record value dict to ProtoBuf.
+
+    Note: `bool` MUST be put in the front of allowed_types if it exists.
+ """ + # Move bool to the front + if bool in allowed_types and allowed_types[0] != bool: + allowed_types.remove(bool) + allowed_types.insert(0, bool) def proto(_v: Any) -> T: return _record_value_to_proto(_v, allowed_types, value_proto_class) @@ -666,8 +456,8 @@ def array_from_proto(array_proto: ProtoArray) -> Array: def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: """Serialize ParametersRecord to ProtoBuf.""" return ProtoParametersRecord( - data_keys=record.data.keys(), - data_values=map(array_to_proto, record.data.values()), + data_keys=record.keys(), + data_values=map(array_to_proto, record.values()), ) @@ -686,9 +476,7 @@ def parameters_record_from_proto( def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: """Serialize MetricsRecord to ProtoBuf.""" return ProtoMetricsRecord( - data=_record_value_dict_to_proto( - record.data, [float, int], ProtoMetricsRecordValue - ) + data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue) ) @@ -707,7 +495,9 @@ def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: """Serialize ConfigsRecord to ProtoBuf.""" return ProtoConfigsRecord( data=_record_value_dict_to_proto( - record.data, [int, float, bool, str, bytes], ProtoConfigsRecordValue + record, + [bool, int, float, str, bytes], + ProtoConfigsRecordValue, ) ) @@ -723,6 +513,21 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord ) +# === Error message === + + +def error_to_proto(error: Error) -> ProtoError: + """Serialize Error to ProtoBuf.""" + reason = error.reason if error.reason else "" + return ProtoError(code=error.code, reason=reason) + + +def error_from_proto(error_proto: ProtoError) -> Error: + """Deserialize Error from ProtoBuf.""" + reason = error_proto.reason if len(error_proto.reason) > 0 else None + return Error(code=error_proto.code, reason=reason) + + # === RecordSet message === @@ -730,24 +535,133 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: """Serialize RecordSet to ProtoBuf.""" return ProtoRecordSet( parameters={ - k: parameters_record_to_proto(v) for k, v in recordset.parameters.items() + k: parameters_record_to_proto(v) + for k, v in recordset.parameters_records.items() + }, + metrics={ + k: metrics_record_to_proto(v) for k, v in recordset.metrics_records.items() + }, + configs={ + k: configs_record_to_proto(v) for k, v in recordset.configs_records.items() }, - metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()}, - configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()}, ) def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: """Deserialize RecordSet from ProtoBuf.""" return RecordSet( - parameters={ + parameters_records={ k: parameters_record_from_proto(v) for k, v in recordset_proto.parameters.items() }, - metrics={ + metrics_records={ k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() }, - configs={ + configs_records={ k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() }, ) + + +# === Message === + + +def message_to_taskins(message: Message) -> TaskIns: + """Create a TaskIns from the Message.""" + md = message.metadata + return TaskIns( + group_id=md.group_id, + run_id=md.run_id, + task=Task( + producer=Node(node_id=0, anonymous=True), # Assume driver node + consumer=Node(node_id=md.dst_node_id, anonymous=False), + ttl=md.ttl, + ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], + 
task_type=md.message_type,
+            recordset=(
+                recordset_to_proto(message.content) if message.has_content() else None
+            ),
+            error=error_to_proto(message.error) if message.has_error() else None,
+        ),
+    )
+
+
+def message_from_taskins(taskins: TaskIns) -> Message:
+    """Create a Message from the TaskIns."""
+    # Retrieve the Metadata
+    metadata = Metadata(
+        run_id=taskins.run_id,
+        message_id=taskins.task_id,
+        src_node_id=taskins.task.producer.node_id,
+        dst_node_id=taskins.task.consumer.node_id,
+        reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "",
+        group_id=taskins.group_id,
+        ttl=taskins.task.ttl,
+        message_type=taskins.task.task_type,
+    )
+
+    # Construct Message
+    return Message(
+        metadata=metadata,
+        content=(
+            recordset_from_proto(taskins.task.recordset)
+            if taskins.task.HasField("recordset")
+            else None
+        ),
+        error=(
+            error_from_proto(taskins.task.error)
+            if taskins.task.HasField("error")
+            else None
+        ),
+    )
+
+
+def message_to_taskres(message: Message) -> TaskRes:
+    """Create a TaskRes from the Message."""
+    md = message.metadata
+    return TaskRes(
+        task_id="",  # This will be generated by the server
+        group_id=md.group_id,
+        run_id=md.run_id,
+        task=Task(
+            producer=Node(node_id=md.src_node_id, anonymous=False),
+            consumer=Node(node_id=0, anonymous=True),  # Assume driver node
+            ttl=md.ttl,
+            ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
+            task_type=md.message_type,
+            recordset=(
+                recordset_to_proto(message.content) if message.has_content() else None
+            ),
+            error=error_to_proto(message.error) if message.has_error() else None,
+        ),
+    )
+
+
+def message_from_taskres(taskres: TaskRes) -> Message:
+    """Create a Message from the TaskRes."""
+    # Retrieve the Metadata
+    metadata = Metadata(
+        run_id=taskres.run_id,
+        message_id=taskres.task_id,
+        src_node_id=taskres.task.producer.node_id,
+        dst_node_id=taskres.task.consumer.node_id,
+        reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "",
+        group_id=taskres.group_id,
+        ttl=taskres.task.ttl,
+        message_type=taskres.task.task_type,
+    )
+
+    # Construct the Message
+    return Message(
+        metadata=metadata,
+        content=(
+            recordset_from_proto(taskres.task.recordset)
+            if taskres.task.HasField("recordset")
+            else None
+        ),
+        error=(
+            error_from_proto(taskres.task.error)
+            if taskres.task.HasField("error")
+            else None
+        ),
+    )
diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py
index 53f40eee5e53..8596e5d2f330 100644
--- a/src/py/flwr/common/serde_test.py
+++ b/src/py/flwr/common/serde_test.py
@@ -14,8 +14,11 @@
 # ==============================================================================
 """(De-)serialization tests."""

+import random
+import string
+from typing import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast
-from typing import Dict, OrderedDict, Union, cast
+import pytest

 # pylint: disable=E0611
 from flwr.proto import transport_pb2 as pb2
@@ -26,20 +29,19 @@
 from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
 # pylint: enable=E0611

-from . import typing
-from .configsrecord import ConfigsRecord
-from .metricsrecord import MetricsRecord
-from .parametersrecord import Array, ParametersRecord
-from .recordset import RecordSet
+from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
+from .message import Error, Message, Metadata
 from .serde import (
     array_from_proto,
     array_to_proto,
     configs_record_from_proto,
     configs_record_to_proto,
+    message_from_taskins,
+    message_from_taskres,
+    message_to_taskins,
+    message_to_taskres,
     metrics_record_from_proto,
     metrics_record_to_proto,
-    named_values_from_proto,
-    named_values_to_proto,
     parameters_record_from_proto,
     parameters_record_to_proto,
     recordset_from_proto,
@@ -48,8 +50,6 @@
     scalar_to_proto,
     status_from_proto,
     status_to_proto,
-    value_from_proto,
-    value_to_proto,
 )
@@ -100,90 +100,135 @@ def test_status_from_proto() -> None:
     assert actual_status == status


-def test_value_serialization_deserialization() -> None:
-    """Test if values are identical after (de-)serialization."""
-    # Prepare
-    values = [
-        # boolean scalar and list
-        True,
-        [True, False, False, True],
-        # bytes scalar and list
-        b"test \x01\x02\x03 !@#$%^&*()",
-        [b"\x0a\x0b", b"\x0c\x0d\x0e", b"\x0f"],
-        # float scalar and list
-        3.14,
-        [2.714, -0.012],
-        # integer scalar and list
-        23,
-        [123456],
-        # string scalar and list
-        "abcdefghijklmnopqrstuvwxy",
-        ["456hgdhfd", "1234567890123456789012345678901", "I'm a string."],
-        # empty list
-        [],
-    ]
-
-    for value in values:
-        # Execute
-        serialized = value_to_proto(cast(typing.Value, value))
-        deserialized = value_from_proto(serialized)
-
-        # Assert
-        if isinstance(value, list):
-            assert isinstance(deserialized, list)
-            assert len(value) == len(deserialized)
-            for elm1, elm2 in zip(value, deserialized):
-                assert elm1 == elm2
-        else:
-            assert value == deserialized
-
-
-def test_named_values_serialization_deserialization() -> None:
-    """Test if named values is identical after (de-)serialization."""
-    # Prepare
-    values = [
-        # boolean scalar and list
-        True,
-        [True, False, False, True],
-        # bytes scalar and list
-        b"test \x01\x02\x03 !@#$%^&*()",
-        [b"\x0a\x0b", b"\x0c\x0d\x0e", b"\x0f"],
-        # float scalar and list
-        3.14,
-        [2.714, -0.012],
-        # integer scalar and list
-        23,
-        [123456],
-        # string scalar and list
-        "abcdefghijklmnopqrstuvwxy",
-        ["456hgdhfd", "1234567890123456789012345678901", "I'm a string."],
-        # empty list
-        [],
-    ]
-    named_values = {f"value {i}": value for i, value in enumerate(values)}
-
-    # Execute
-    serialized = named_values_to_proto(cast(Dict[str, typing.Value], named_values))
-    deserialized = named_values_from_proto(serialized)
-
-    # Assert
-    assert len(named_values) == len(deserialized)
-    for name in named_values:
-        expected = named_values[name]
-        actual = deserialized[name]
-        if isinstance(expected, list):
-            assert isinstance(actual, list)
-            assert len(expected) == len(actual)
-            for elm1, elm2 in zip(expected, actual):
-                assert elm1 == elm2
+T = TypeVar("T")
+
+
+class RecordMaker:
+    """A record maker based on a seeded random number generator."""
+
+    def __init__(self, state: int = 42) -> None:
+        self.rng = random.Random(state)
+
+    def randbytes(self, n: int) -> bytes:
+        """Create a random bytes object."""
+        return self.rng.getrandbits(n * 8).to_bytes(n, "little")
+
+    def get_str(self, length: Optional[int] = None) -> str:
+        """Create a random string."""
+        char_pool = (
+            string.ascii_letters + string.digits + " !@#$%^&*()_-+=[]|;':,./<>?{}"
+        )
+        if length is None:
+            length = self.rng.randint(1, 10)
+        return "".join(self.rng.choices(char_pool, k=length))
+
+    def get_value(self, dtype: Type[T]) -> T:
+        """Create a value of a given type."""
+        ret: Any = None
+        if dtype == bool:
+            ret = self.rng.random() < 0.5
+        elif dtype == str:
+            ret = self.get_str(self.rng.randint(10, 100))
+        elif dtype == int:
+            ret = self.rng.randint(-1 << 30, 1 << 30)
+        elif dtype == float:
+            ret = (self.rng.random() - 0.5) * (2.0 ** self.rng.randint(0, 50))
+        elif dtype == bytes:
+            ret = self.randbytes(self.rng.randint(10, 100))
         else:
-            assert expected == actual
+            raise NotImplementedError(f"Unsupported dtype: {dtype}")
+        return cast(T, ret)
+
+    def array(self) -> Array:
+        """Create an Array."""
+        dtypes = ("float", "int")
+        stypes = ("torch", "tf", "numpy")
+        max_shape_size = 100
+        max_shape_dim = 10
+        min_max_bytes_size = (10, 1000)
+
+        dtype = self.rng.choice(dtypes)
+        shape = [
+            self.rng.randint(1, max_shape_size)
+            for _ in range(self.rng.randint(1, max_shape_dim))
+        ]
+        stype = self.rng.choice(stypes)
+        data = self.randbytes(self.rng.randint(*min_max_bytes_size))
+        return Array(dtype=dtype, shape=shape, stype=stype, data=data)
+
+    def parameters_record(self) -> ParametersRecord:
+        """Create a ParametersRecord."""
+        num_arrays = self.rng.randint(1, 5)
+        arrays = OrderedDict(
+            [(self.get_str(), self.array()) for i in range(num_arrays)]
+        )
+        return ParametersRecord(arrays, keep_input=False)
+
+    def metrics_record(self) -> MetricsRecord:
+        """Create a MetricsRecord."""
+        num_entries = self.rng.randint(1, 5)
+        types = (float, int)
+        return MetricsRecord(
+            metrics_dict={
+                self.get_str(): self.get_value(self.rng.choice(types))
+                for _ in range(num_entries)
+            },
+            keep_input=False,
+        )
+
+    def configs_record(self) -> ConfigsRecord:
+        """Create a ConfigsRecord."""
+        num_entries = self.rng.randint(1, 5)
+        types = (str, int, float, bytes, bool)
+        return ConfigsRecord(
+            configs_dict={
+                self.get_str(): self.get_value(self.rng.choice(types))
+                for _ in range(num_entries)
+            },
+            keep_input=False,
+        )
+
+    def recordset(
+        self,
+        num_params_records: int,
+        num_metrics_records: int,
+        num_configs_records: int,
+    ) -> RecordSet:
+        """Create a RecordSet."""
+        return RecordSet(
+            parameters_records={
+                self.get_str(): self.parameters_record()
+                for _ in range(num_params_records)
+            },
+            metrics_records={
+                self.get_str(): self.metrics_record()
+                for _ in range(num_metrics_records)
+            },
+            configs_records={
+                self.get_str(): self.configs_record()
+                for _ in range(num_configs_records)
+            },
+        )
+
+    def metadata(self) -> Metadata:
+        """Create a Metadata."""
+        return Metadata(
+            run_id=self.rng.randint(0, 1 << 30),
+            message_id=self.get_str(64),
+            group_id=self.get_str(30),
+            src_node_id=self.rng.randint(0, 1 << 63),
+            dst_node_id=self.rng.randint(0, 1 << 63),
+            reply_to_message=self.get_str(64),
+            ttl=self.get_str(10),
+            message_type=self.get_str(10),
+        )


 def test_array_serialization_deserialization() -> None:
     """Test serialization and deserialization of Array."""
     # Prepare
-    original = Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")
+    maker = RecordMaker()
+    original = maker.array()

     # Execute
     proto = array_to_proto(original)
@@ -197,15 +242,8 @@ def test_array_serialization_deserialization() -> None:

 def test_parameters_record_serialization_deserialization() -> None:
     """Test serialization and deserialization of ParametersRecord."""
     # Prepare
-    original = ParametersRecord(
-        array_dict=OrderedDict(
-            [
-                ("k1", Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")),
-                ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")),
-            ]
-        ),
-        keep_input=False,
-    )
+    maker = RecordMaker()
+    original = maker.parameters_record()

     # Execute
     proto = parameters_record_to_proto(original)
@@ -213,15 +251,14 @@ def 
test_parameters_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoParametersRecord) - assert original.data == deserialized.data + assert original == deserialized def test_metrics_record_serialization_deserialization() -> None: """Test serialization and deserialization of MetricsRecord.""" # Prepare - original = MetricsRecord( - metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False - ) + maker = RecordMaker() + original = maker.metrics_record() # Execute proto = metrics_record_to_proto(original) @@ -229,15 +266,14 @@ def test_metrics_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoMetricsRecord) - assert original.data == deserialized.data + assert original == deserialized def test_configs_record_serialization_deserialization() -> None: """Test serialization and deserialization of ConfigsRecord.""" # Prepare - original = ConfigsRecord( - configs_dict={"learning_rate": 0.01, "batch_size": 32}, keep_input=False - ) + maker = RecordMaker() + original = maker.configs_record() # Execute proto = configs_record_to_proto(original) @@ -245,60 +281,14 @@ def test_configs_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoConfigsRecord) - assert original.data == deserialized.data + assert original == deserialized def test_recordset_serialization_deserialization() -> None: """Test serialization and deserialization of RecordSet.""" # Prepare - encoder_params_record = ParametersRecord( - array_dict=OrderedDict( - [ - ( - "k1", - Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234"), - ), - ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")), - ] - ), - keep_input=False, - ) - decoder_params_record = ParametersRecord( - array_dict=OrderedDict( - [ - ( - "k1", - Array( - dtype="float", shape=[32, 32, 4], stype="dense", data=b"0987" - ), - ), - ] - ), - keep_input=False, - ) - - original = RecordSet( - parameters={ - "encoder_parameters": encoder_params_record, - "decoder_parameters": decoder_params_record, - }, - metrics={ - "acc_metrics": MetricsRecord( - metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False - ) - }, - configs={ - "my_configs": ConfigsRecord( - configs_dict={ - "learning_rate": 0.01, - "batch_size": 32, - "public_key": b"21f8sioj@!#", - "log": "Hello, world!", - }, - keep_input=False, - ) - }, - ) + maker = RecordMaker(state=0) + original = maker.recordset(2, 2, 1) # Execute proto = recordset_to_proto(original) @@ -307,3 +297,93 @@ def test_recordset_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoRecordSet) assert original == deserialized + + +@pytest.mark.parametrize( + "content_fn, error_fn", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + ), # check when only content is set + (None, lambda code: Error(code=code)), # check when only error is set + ], +) +def test_message_to_and_from_taskins( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], +) -> None: + """Test Message to and from TaskIns.""" + # Prepare + + maker = RecordMaker(state=1) + metadata = maker.metadata() + # pylint: disable-next=protected-access + metadata._src_node_id = 0 # Assume driver node + + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + # Execute + taskins = message_to_taskins(original) + taskins.task_id = metadata.message_id + deserialized = 
message_from_taskins(taskins) + + # Assert + if original.has_content(): + assert original.content == deserialized.content + if original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata + + +@pytest.mark.parametrize( + "content_fn, error_fn", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + ), # check when only content is set + (None, lambda code: Error(code=code)), # check when only error is set + ], +) +def test_message_to_and_from_taskres( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], +) -> None: + """Test Message to and from TaskRes.""" + # Prepare + maker = RecordMaker(state=2) + metadata = maker.metadata() + metadata.dst_node_id = 0 # Assume driver node + + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + # Execute + taskres = message_to_taskres(original) + taskres.task_id = metadata.message_id + deserialized = message_from_taskres(taskres) + + # Assert + if original.has_content(): + assert original.content == deserialized.content + if original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index fed8b5a978bc..8eb594085d31 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -32,7 +32,7 @@ FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1") FLWR_TELEMETRY_LOGGING = os.getenv("FLWR_TELEMETRY_LOGGING", "0") -TELEMETRY_EVENTS_URL = "https://telemetry.flower.dev/api/v1/event" +TELEMETRY_EVENTS_URL = "https://telemetry.flower.ai/api/v1/event" LOGGER_NAME = "flwr-telemetry" LOGGER_LEVEL = logging.DEBUG @@ -137,8 +137,8 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A RUN_FLEET_API_LEAVE = auto() # Driver API and Fleet API - RUN_SERVER_ENTER = auto() - RUN_SERVER_LEAVE = auto() + RUN_SUPERLINK_ENTER = auto() + RUN_SUPERLINK_LEAVE = auto() # Simulation START_SIMULATION_ENTER = auto() @@ -152,9 +152,13 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A START_DRIVER_ENTER = auto() START_DRIVER_LEAVE = auto() - # SuperNode: flower-client - RUN_CLIENT_ENTER = auto() - RUN_CLIENT_LEAVE = auto() + # flower-client-app + RUN_CLIENT_APP_ENTER = auto() + RUN_CLIENT_APP_LEAVE = auto() + + # flower-server-app + RUN_SERVER_APP_ENTER = auto() + RUN_SERVER_APP_LEAVE = auto() # Use the ThreadPoolExecutor with max_workers=1 to have a queue diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index 9545b4b68073..6808c66606b1 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -1,6 +1,5 @@ """Flower package version helper.""" - import importlib.metadata as importlib_metadata from typing import Tuple diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py deleted file mode 100644 index bfa0098f68e2..000000000000 --- a/src/py/flwr/driver/app_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
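For orientation, a condensed sketch of the Message-to-TaskIns round trip that the tests above exercise (field values are placeholders, not defaults from the library):

```python
from flwr.common import RecordSet
from flwr.common.message import Message, Metadata
from flwr.common.serde import message_from_taskins, message_to_taskins

metadata = Metadata(
    run_id=1,
    message_id="",  # assigned by the server in practice
    src_node_id=0,  # driver node, as message_to_taskins assumes
    dst_node_id=42,
    reply_to_message="",
    group_id="group-0",
    ttl="",
    message_type="fit",
)
original = Message(
    metadata=metadata,
    content=RecordSet(
        parameters_records={}, metrics_records={}, configs_records={}
    ),
    error=None,
)

taskins = message_to_taskins(original)
taskins.task_id = metadata.message_id  # mirror what the tests above do
restored = message_from_taskins(taskins)
assert restored.content == original.content
```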
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower Driver app tests.""" - - -import threading -import time -import unittest -from unittest.mock import MagicMock - -from flwr.driver.app import update_client_manager -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunResponse, - GetNodesResponse, -) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.server.client_manager import SimpleClientManager - - -class TestClientManagerWithDriver(unittest.TestCase): - """Tests for ClientManager. - - Considering multi-threading, all tests assume that the `update_client_manager()` - updates the ClientManager every 3 seconds. - """ - - def test_simple_client_manager_update(self) -> None: - """Tests if the node update works correctly.""" - # Prepare - expected_nodes = [Node(node_id=i, anonymous=False) for i in range(100)] - expected_updated_nodes = [ - Node(node_id=i, anonymous=False) for i in range(80, 120) - ] - driver = MagicMock() - driver.stub = "driver stub" - driver.create_run.return_value = CreateRunResponse(run_id=1) - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) - client_manager = SimpleClientManager() - lock = threading.Lock() - - # Execute - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - client_manager, - lock, - ), - daemon=True, - ) - thread.start() - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_nodes)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_updated_nodes) - while True: - with lock: - if len(client_manager.all()) == len(expected_updated_nodes): - break - time.sleep(1.3) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Simulate `driver.disconnect()` - driver.stub = None - - # Assert - driver.create_run.assert_called_once() - assert node_ids == {node.node_id for node in expected_nodes} - assert updated_node_ids == {node.node_id for node in expected_updated_nodes} - - # Exit - thread.join() diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py deleted file mode 100644 index 512a2001165e..000000000000 --- a/src/py/flwr/driver/driver.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Flower driver service client.""" - - -from typing import Iterable, List, Optional, Tuple - -from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - GetNodesRequest, - PullTaskResRequest, - PushTaskInsRequest, -) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 - - -class Driver: - """`Driver` class provides an interface to the Driver API. - - Parameters - ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. - """ - - def __init__( - self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, - ) -> None: - self.addr = driver_service_address - self.certificates = certificates - self.grpc_driver: Optional[GrpcDriver] = None - self.run_id: Optional[int] = None - self.node = Node(node_id=0, anonymous=True) - - def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: - # Check if the GrpcDriver is initialized - if self.grpc_driver is None or self.run_id is None: - # Connect and create run - self.grpc_driver = GrpcDriver( - driver_service_address=self.addr, certificates=self.certificates - ) - self.grpc_driver.connect() - res = self.grpc_driver.create_run(CreateRunRequest()) - self.run_id = res.run_id - - return self.grpc_driver, self.run_id - - def get_nodes(self) -> List[Node]: - """Get node IDs.""" - grpc_driver, run_id = self._get_grpc_driver_and_run_id() - - # Call GrpcDriver method - res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) - return list(res.nodes) - - def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: - """Schedule tasks.""" - grpc_driver, run_id = self._get_grpc_driver_and_run_id() - - # Set run_id - for task_ins in task_ins_list: - task_ins.run_id = run_id - - # Call GrpcDriver method - res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) - return list(res.task_ids) - - def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: - """Get task results.""" - grpc_driver, _ = self._get_grpc_driver_and_run_id() - - # Call GrpcDriver method - res = grpc_driver.pull_task_res( - PullTaskResRequest(node=self.node, task_ids=task_ids) - ) - return list(res.task_res_list) - - def __del__(self) -> None: - """Disconnect GrpcDriver if connected.""" - # Check if GrpcDriver is initialized - if self.grpc_driver is None: - return - - # Disconnect - self.grpc_driver.disconnect() diff --git a/src/py/flwr/proto/error_pb2.py b/src/py/flwr/proto/error_pb2.py new file mode 100644 index 000000000000..41721ae08804 --- /dev/null +++ b/src/py/flwr/proto/error_pb2.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
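The new `error.proto` message generated below is what the `error_to_proto`/`error_from_proto` helpers added earlier in this diff serialize to and from; a quick illustrative round trip:

```python
from flwr.common.message import Error
from flwr.common.serde import error_from_proto, error_to_proto

err = Error(code=1, reason="client failed to fit")
assert error_from_proto(error_to_proto(err)) == err

# An unset reason travels as "" on the wire and comes back as None.
assert error_from_proto(error_to_proto(Error(code=2))).reason is None
```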
+# source: flwr/proto/error.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_ERROR']._serialized_start=38 + _globals['_ERROR']._serialized_end=75 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/error_pb2.pyi b/src/py/flwr/proto/error_pb2.pyi new file mode 100644 index 000000000000..1811e5aa0ca8 --- /dev/null +++ b/src/py/flwr/proto/error_pb2.pyi @@ -0,0 +1,25 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class Error(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + CODE_FIELD_NUMBER: builtins.int + REASON_FIELD_NUMBER: builtins.int + code: builtins.int + reason: typing.Text + def __init__(self, + *, + code: builtins.int = ..., + reason: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ... +global___Error = Error diff --git a/src/py/flwr/proto/error_pb2_grpc.py b/src/py/flwr/proto/error_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/error_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/error_pb2_grpc.pyi b/src/py/flwr/proto/error_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/error_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index f9b2180b15dd..4d5f863e88dd 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -15,33 +15,20 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 +from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xff\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\x12-\n\x02sa\x18g \x01(\x0b\x32\x1d.flwr.proto.SecureAggregationB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 
\x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_TASK'].fields_by_name['legacy_server_message']._options = None - _globals['_TASK'].fields_by_name['legacy_server_message']._serialized_options = b'\030\001' - _globals['_TASK'].fields_by_name['legacy_client_message']._options = None - _globals['_TASK'].fields_by_name['legacy_client_message']._serialized_options = b'\030\001' - _globals['_TASK'].fields_by_name['sa']._options = None - _globals['_TASK'].fields_by_name['sa']._serialized_options = b'\030\001' - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._options = None - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_options = b'8\001' - _globals['_TASK']._serialized_start=117 - _globals['_TASK']._serialized_end=500 - _globals['_TASKINS']._serialized_start=502 - _globals['_TASKINS']._serialized_end=594 - _globals['_TASKRES']._serialized_start=596 - _globals['_TASKRES']._serialized_end=688 - _globals['_VALUE']._serialized_start=691 - _globals['_VALUE']._serialized_end=1023 - _globals['_SECUREAGGREGATION']._serialized_start=1026 - _globals['_SECUREAGGREGATION']._serialized_end=1186 - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1117 - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1186 + _globals['_TASK']._serialized_start=141 + _globals['_TASK']._serialized_end=387 + _globals['_TASKINS']._serialized_start=389 + _globals['_TASKINS']._serialized_end=481 + _globals['_TASKRES']._serialized_start=483 + _globals['_TASKRES']._serialized_end=575 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 39119797c9e4..b9c10139cfb3 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -3,9 +3,9 @@ isort:skip_file """ import builtins +import flwr.proto.error_pb2 import flwr.proto.node_pb2 import flwr.proto.recordset_pb2 -import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -24,9 +24,7 @@ class Task(google.protobuf.message.Message): ANCESTRY_FIELD_NUMBER: builtins.int TASK_TYPE_FIELD_NUMBER: builtins.int RECORDSET_FIELD_NUMBER: builtins.int - LEGACY_SERVER_MESSAGE_FIELD_NUMBER: builtins.int - LEGACY_CLIENT_MESSAGE_FIELD_NUMBER: builtins.int - SA_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int @property def producer(self) -> flwr.proto.node_pb2.Node: ... @property @@ -40,11 +38,7 @@ class Task(google.protobuf.message.Message): @property def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ... @property - def legacy_server_message(self) -> flwr.proto.transport_pb2.ServerMessage: ... - @property - def legacy_client_message(self) -> flwr.proto.transport_pb2.ClientMessage: ... - @property - def sa(self) -> global___SecureAggregation: ... + def error(self) -> flwr.proto.error_pb2.Error: ... 
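As the regenerated stubs show, `Task` now carries an optional `error` submessage in place of the removed `legacy_server_message`, `legacy_client_message`, and `sa` fields; presence is checked the usual protobuf way:

```python
from flwr.proto.error_pb2 import Error as ProtoError
from flwr.proto.task_pb2 import Task

task = Task(error=ProtoError(code=1, reason="node unreachable"))
assert task.HasField("error")
assert not task.HasField("recordset")
```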
def __init__(self, *, producer: typing.Optional[flwr.proto.node_pb2.Node] = ..., @@ -55,12 +49,10 @@ class Task(google.protobuf.message.Message): ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., - legacy_server_message: typing.Optional[flwr.proto.transport_pb2.ServerMessage] = ..., - legacy_client_message: typing.Optional[flwr.proto.transport_pb2.ClientMessage] = ..., - sa: typing.Optional[global___SecureAggregation] = ..., + error: typing.Optional[flwr.proto.error_pb2.Error] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... global___Task = Task class TaskIns(google.protobuf.message.Message): @@ -106,79 +98,3 @@ class TaskRes(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskRes = TaskRes - -class Value(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - DOUBLE_FIELD_NUMBER: builtins.int - SINT64_FIELD_NUMBER: builtins.int - BOOL_FIELD_NUMBER: builtins.int - STRING_FIELD_NUMBER: builtins.int - BYTES_FIELD_NUMBER: builtins.int - DOUBLE_LIST_FIELD_NUMBER: builtins.int - SINT64_LIST_FIELD_NUMBER: builtins.int - BOOL_LIST_FIELD_NUMBER: builtins.int - STRING_LIST_FIELD_NUMBER: builtins.int - BYTES_LIST_FIELD_NUMBER: builtins.int - double: builtins.float - """Single element""" - - sint64: builtins.int - bool: builtins.bool - string: typing.Text - bytes: builtins.bytes - @property - def double_list(self) -> flwr.proto.recordset_pb2.DoubleList: - """List types""" - pass - @property - def sint64_list(self) -> flwr.proto.recordset_pb2.Sint64List: ... - @property - def bool_list(self) -> flwr.proto.recordset_pb2.BoolList: ... - @property - def string_list(self) -> flwr.proto.recordset_pb2.StringList: ... - @property - def bytes_list(self) -> flwr.proto.recordset_pb2.BytesList: ... 
- def __init__(self, - *, - double: builtins.float = ..., - sint64: builtins.int = ..., - bool: builtins.bool = ..., - string: typing.Text = ..., - bytes: builtins.bytes = ..., - double_list: typing.Optional[flwr.proto.recordset_pb2.DoubleList] = ..., - sint64_list: typing.Optional[flwr.proto.recordset_pb2.Sint64List] = ..., - bool_list: typing.Optional[flwr.proto.recordset_pb2.BoolList] = ..., - string_list: typing.Optional[flwr.proto.recordset_pb2.StringList] = ..., - bytes_list: typing.Optional[flwr.proto.recordset_pb2.BytesList] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... -global___Value = Value - -class SecureAggregation(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class NamedValuesEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - @property - def value(self) -> global___Value: ... - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Optional[global___Value] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - NAMED_VALUES_FIELD_NUMBER: builtins.int - @property - def named_values(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___Value]: ... - def __init__(self, - *, - named_values: typing.Optional[typing.Mapping[typing.Text, global___Value]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["named_values",b"named_values"]) -> None: ... -global___SecureAggregation = SecureAggregation diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 74abe8dd463c..633bd668b520 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -16,25 +16,37 @@ from . import strategy -from .app import ServerConfig as ServerConfig +from . 
import workflow as workflow from .app import run_driver_api as run_driver_api from .app import run_fleet_api as run_fleet_api -from .app import run_server as run_server +from .app import run_superlink as run_superlink from .app import start_server as start_server from .client_manager import ClientManager as ClientManager from .client_manager import SimpleClientManager as SimpleClientManager +from .compat import LegacyContext as LegacyContext +from .compat import start_driver as start_driver +from .driver import Driver as Driver from .history import History as History +from .run_serverapp import run_server_app as run_server_app from .server import Server as Server +from .server_app import ServerApp as ServerApp +from .server_config import ServerConfig as ServerConfig __all__ = [ "ClientManager", + "Driver", "History", + "LegacyContext", "run_driver_api", "run_fleet_api", - "run_server", + "run_server_app", + "run_superlink", "Server", + "ServerApp", "ServerConfig", "SimpleClientManager", + "start_driver", "start_server", "strategy", + "workflow", ] diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 636207e7a859..e04cfb37e118 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -14,17 +14,14 @@ # ============================================================================== """Flower server app.""" - import argparse +import asyncio import importlib.util import sys import threading -from dataclasses import dataclass from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path -from signal import SIGINT, SIGTERM, signal -from types import FrameType from typing import List, Optional, Tuple import grpc @@ -33,33 +30,29 @@ from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, - TRANSPORT_TYPE_GRPC_BIDI, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_VCE, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log -from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 - add_DriverServicer_to_server, -) from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 add_FleetServicer_to_server, ) -from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611 - add_FlowerServiceServicer_to_server, -) -from flwr.server.client_manager import ClientManager, SimpleClientManager -from flwr.server.driver.driver_servicer import DriverServicer -from flwr.server.fleet.grpc_bidi.driver_client_manager import DriverClientManager -from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer -from flwr.server.fleet.grpc_bidi.grpc_server import ( + +from .client_manager import ClientManager +from .history import History +from .server import Server, init_defaults, run_fl +from .server_config import ServerConfig +from .strategy import Strategy +from .superlink.driver.driver_grpc import run_driver_api_grpc +from .superlink.fleet.grpc_bidi.grpc_server import ( generic_create_grpc_server, start_grpc_server, ) -from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer -from flwr.server.history import History -from flwr.server.server import Server -from flwr.server.state import StateFactory -from flwr.server.strategy import FedAvg, Strategy +from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer +from .superlink.fleet.vce import start_vce +from .superlink.state import StateFactory ADDRESS_DRIVER_API = "0.0.0.0:9091" ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092" @@ -69,18 +62,6 @@ 
DATABASE = ":flwr-in-memory-state:" -@dataclass -class ServerConfig: - """Flower server config. - - All attributes have default values which allows users to configure just the ones - they care about. - """ - - num_rounds: int = 1 - round_timeout: Optional[float] = None - - def start_server( # pylint: disable=too-many-arguments,too-many-locals *, server_address: str = ADDRESS_FLEET_API_GRPC_BIDI, @@ -200,52 +181,11 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals return hist -def init_defaults( - server: Optional[Server], - config: Optional[ServerConfig], - strategy: Optional[Strategy], - client_manager: Optional[ClientManager], -) -> Tuple[Server, ServerConfig]: - """Create server instance if none was given.""" - if server is None: - if client_manager is None: - client_manager = SimpleClientManager() - if strategy is None: - strategy = FedAvg() - server = Server(client_manager=client_manager, strategy=strategy) - elif strategy is not None: - log(WARN, "Both server and strategy were provided, ignoring strategy") - - # Set default config values - if config is None: - config = ServerConfig() - - return server, config - - -def run_fl( - server: Server, - config: ServerConfig, -) -> History: - """Train a model on the given server and return the History object.""" - hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout) - log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed)) - log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit)) - log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) - log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized)) - log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) - - # Graceful shutdown - server.disconnect_all_clients(timeout=config.round_timeout) - - return hist - - def run_driver_api() -> None: """Run Flower server (Driver API).""" log(INFO, "Starting Flower server (Driver API)") event(EventType.RUN_DRIVER_API_ENTER) - args = _parse_args_driver().parse_args() + args = _parse_args_run_driver_api().parse_args() # Parse IP address parsed_address = parse_address(args.driver_api_address) @@ -261,17 +201,17 @@ def run_driver_api() -> None: state_factory = StateFactory(args.database) # Start server - grpc_server: grpc.Server = _run_driver_api_grpc( + grpc_server: grpc.Server = run_driver_api_grpc( address=address, state_factory=state_factory, certificates=certificates, ) # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_DRIVER_API_LEAVE, grpc_servers=[grpc_server], bckg_threads=[], - event_type=EventType.RUN_DRIVER_API_LEAVE, ) # Block @@ -282,7 +222,7 @@ def run_fleet_api() -> None: """Run Flower server (Fleet API).""" log(INFO, "Starting Flower server (Fleet API)") event(EventType.RUN_FLEET_API_ENTER) - args = _parse_args_fleet().parse_args() + args = _parse_args_run_fleet_api().parse_args() # Obtain certificates certificates = _try_obtain_certificates(args) @@ -319,19 +259,6 @@ def run_fleet_api() -> None: ) fleet_thread.start() bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_BIDI: - address_arg = args.grpc_fleet_api_address - parsed_address = parse_address(address_arg) - if not parsed_address: - sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - fleet_server = _run_fleet_api_grpc_bidi( - address=address, - 
state_factory=state_factory, - certificates=certificates, - ) - grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: address_arg = args.grpc_rere_fleet_api_address parsed_address = parse_address(address_arg) @@ -349,10 +276,10 @@ def run_fleet_api() -> None: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_FLEET_API_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_FLEET_API_LEAVE, ) # Block @@ -363,11 +290,11 @@ def run_fleet_api() -> None: # pylint: disable=too-many-branches, too-many-locals, too-many-statements -def run_server() -> None: +def run_superlink() -> None: """Run Flower server (Driver API and Fleet API).""" log(INFO, "Starting Flower server") - event(EventType.RUN_SERVER_ENTER) - args = _parse_args_server().parse_args() + event(EventType.RUN_SUPERLINK_ENTER) + args = _parse_args_run_superlink().parse_args() # Parse IP address parsed_address = parse_address(args.driver_api_address) @@ -383,7 +310,7 @@ def run_server() -> None: state_factory = StateFactory(args.database) # Start Driver API - driver_server: grpc.Server = _run_driver_api_grpc( + driver_server: grpc.Server = run_driver_api_grpc( address=address, state_factory=state_factory, certificates=certificates, @@ -418,19 +345,6 @@ def run_server() -> None: ) fleet_thread.start() bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_BIDI: - address_arg = args.grpc_bidi_fleet_api_address - parsed_address = parse_address(address_arg) - if not parsed_address: - sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - fleet_server = _run_fleet_api_grpc_bidi( - address=address, - state_factory=state_factory, - certificates=certificates, - ) - grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: address_arg = args.grpc_rere_fleet_api_address parsed_address = parse_address(address_arg) @@ -444,14 +358,25 @@ def run_server() -> None: certificates=certificates, ) grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_VCE: + f_stop = asyncio.Event() # Does nothing + _run_fleet_api_vce( + num_supernodes=args.num_supernodes, + client_app_attr=args.client_app, + backend_name=args.backend, + backend_config_json_stream=args.backend_config, + app_dir=args.app_dir, + state_factory=state_factory, + f_stop=f_stop, + ) else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_SUPERLINK_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_SERVER_LEAVE, ) # Block @@ -486,105 +411,6 @@ def _try_obtain_certificates( return certificates -def _register_exit_handlers( - grpc_servers: List[grpc.Server], - bckg_threads: List[threading.Thread], - event_type: EventType, -) -> None: - default_handlers = { - SIGINT: None, - SIGTERM: None, - } - - def graceful_exit_handler( # type: ignore - signalnum, - frame: FrameType, # pylint: disable=unused-argument - ) -> None: - """Exit handler to be registered with signal.signal. - - When called will reset signal handler to original signal handler from - default_handlers. 
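The new `TRANSPORT_TYPE_VCE` branch above plugs the simulated Fleet API into `run_superlink`. A hedged sketch of exercising it, assuming a local `client.py` exposing `app = ClientApp(...)`; driving the CLI parser through `sys.argv` is purely illustrative, only the flags themselves come from this patch:

```python
# Sketch only: run the SuperLink with the new simulated (VCE) Fleet API.
# `client:app` refers to a hypothetical client.py defining `app`.
import sys

from flwr.server import run_superlink

sys.argv = [
    "flower-superlink",  # argv[0]; the actual entry-point name is assumed
    "--vce",
    "--client-app", "client:app",
    "--num-supernodes", "10",
    "--backend-config", '{"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}',
]
run_superlink()  # starts the Driver API plus the VCE Fleet API, then blocks
```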
- """ - # Reset to default handler - signal(signalnum, default_handlers[signalnum]) - - event_res = event(event_type=event_type) - - for grpc_server in grpc_servers: - grpc_server.stop(grace=1) - - for bckg_thread in bckg_threads: - bckg_thread.join() - - # Ensure event has happend - event_res.result() - - # Setup things for graceful exit - sys.exit(0) - - default_handlers[SIGINT] = signal( # type: ignore - SIGINT, - graceful_exit_handler, # type: ignore - ) - default_handlers[SIGTERM] = signal( # type: ignore - SIGTERM, - graceful_exit_handler, # type: ignore - ) - - -def _run_driver_api_grpc( - address: str, - state_factory: StateFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> grpc.Server: - """Run Driver API (gRPC, request-response).""" - # Create Driver API gRPC server - driver_servicer: grpc.Server = DriverServicer( - state_factory=state_factory, - ) - driver_add_servicer_to_server_fn = add_DriverServicer_to_server - driver_grpc_server = generic_create_grpc_server( - servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn), - server_address=address, - max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=certificates, - ) - - log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address) - driver_grpc_server.start() - - return driver_grpc_server - - -def _run_fleet_api_grpc_bidi( - address: str, - state_factory: StateFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> grpc.Server: - """Run Fleet API (gRPC, bidirectional streaming).""" - # DriverClientManager - driver_client_manager = DriverClientManager( - state_factory=state_factory, - ) - - # Create (legacy) Fleet API gRPC server - fleet_servicer = FlowerServiceServicer( - client_manager=driver_client_manager, - ) - fleet_add_servicer_to_server_fn = add_FlowerServiceServicer_to_server - fleet_grpc_server = generic_create_grpc_server( - servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn), - server_address=address, - max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=certificates, - ) - - log(INFO, "Flower ECE: Starting Fleet API (gRPC-bidi) on %s", address) - fleet_grpc_server.start() - - return fleet_grpc_server - - def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, @@ -593,7 +419,7 @@ def _run_fleet_api_grpc_rere( """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server fleet_servicer = FleetServicer( - state=state_factory.state(), + state_factory=state_factory, ) fleet_add_servicer_to_server_fn = add_FleetServicer_to_server fleet_grpc_server = generic_create_grpc_server( @@ -609,6 +435,29 @@ def _run_fleet_api_grpc_rere( return fleet_grpc_server +# pylint: disable=too-many-arguments +def _run_fleet_api_vce( + num_supernodes: int, + client_app_attr: str, + backend_name: str, + backend_config_json_stream: str, + app_dir: str, + state_factory: StateFactory, + f_stop: asyncio.Event, +) -> None: + log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)") + + start_vce( + num_supernodes=num_supernodes, + client_app_attr=client_app_attr, + backend_name=backend_name, + backend_config_json_stream=backend_config_json_stream, + state_factory=state_factory, + app_dir=app_dir, + f_stop=f_stop, + ) + + # pylint: disable=import-outside-toplevel,too-many-arguments def _run_fleet_api_rest( host: str, @@ -622,7 +471,7 @@ def _run_fleet_api_rest( try: import uvicorn - from flwr.server.fleet.rest_rere.rest_api import app as fast_api_app + from flwr.server.superlink.fleet.rest_rere.rest_api import app as 
fast_api_app except ModuleNotFoundError: sys.exit(MISSING_EXTRA_REST) if workers != 1: @@ -645,7 +494,7 @@ def _run_fleet_api_rest( raise ValueError(validation_exceptions) uvicorn.run( - app="flwr.server.fleet.rest_rere.rest_api:app", + app="flwr.server.superlink.fleet.rest_rere.rest_api:app", port=port, host=host, reload=False, @@ -682,7 +531,7 @@ def _validate_ssl_files( return validation_exceptions -def _parse_args_driver() -> argparse.ArgumentParser: +def _parse_args_run_driver_api() -> argparse.ArgumentParser: """Parse command line arguments for Driver API.""" parser = argparse.ArgumentParser( description="Start a Flower Driver API server. " @@ -699,7 +548,7 @@ def _parse_args_driver() -> argparse.ArgumentParser: return parser -def _parse_args_fleet() -> argparse.ArgumentParser: +def _parse_args_run_fleet_api() -> argparse.ArgumentParser: """Parse command line arguments for Fleet API.""" parser = argparse.ArgumentParser( description="Start a Flower Fleet API server." @@ -716,7 +565,7 @@ def _parse_args_fleet() -> argparse.ArgumentParser: return parser -def _parse_args_server() -> argparse.ArgumentParser: +def _parse_args_run_superlink() -> argparse.ArgumentParser: """Parse command line arguments for both Driver API and Fleet API.""" parser = argparse.ArgumentParser( description="This will start a Flower server " @@ -785,12 +634,13 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: const=TRANSPORT_TYPE_REST, help="Start a Fleet API server (REST, experimental)", ) + ex_group.add_argument( - "--grpc-bidi", + "--vce", action="store_const", dest="fleet_api_type", - const=TRANSPORT_TYPE_GRPC_BIDI, - help="Start a Fleet API server (gRPC-bidi)", + const=TRANSPORT_TYPE_VCE, + help="Start a Fleet API server (VirtualClientEngine)", ) # Fleet API gRPC-rere options @@ -829,12 +679,35 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: default=1, ) - # Fleet API gRPC-bidi options - grpc_bidi_group = parser.add_argument_group( - "Fleet API (gRPC-bidi) server options", "" + # Fleet API VCE options + vce_group = parser.add_argument_group("Fleet API (VCE) server options", "") + vce_group.add_argument( + "--client-app", + help="For example: `client:app` or `project.package.module:wrapper.app`.", ) - grpc_bidi_group.add_argument( - "--grpc-bidi-fleet-api-address", - help="Fleet API (gRPC-bidi) server address (IPv4, IPv6, or a domain name)", - default=ADDRESS_FLEET_API_GRPC_RERE, + vce_group.add_argument( + "--num-supernodes", + type=int, + help="Number of simulated SuperNodes.", + ) + vce_group.add_argument( + "--backend", + default="ray", + type=str, + help="Simulation backend that executes the ClientApp.", + ) + vce_group.add_argument( + "--backend-config", + type=str, + default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}', + help='A JSON formatted stream, e.g. \'{"<keyA>":<value>, "<keyB>":<value>}\' to ' + "configure a backend. Values supported in <value> are those included by " + "`flwr.common.typing.ConfigsRecordValues`. ", + ) + parser.add_argument( + "--app-dir", + default="", + help="Add specified directory to the PYTHONPATH and load " + "ClientApp from there."
+ " Default: current working directory.", ) diff --git a/src/py/flwr/server/client_manager_test.py b/src/py/flwr/server/client_manager_test.py index 8145b9b2ab7f..5820881b6aad 100644 --- a/src/py/flwr/server/client_manager_test.py +++ b/src/py/flwr/server/client_manager_test.py @@ -18,7 +18,7 @@ from unittest.mock import MagicMock from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def test_simple_client_manager_register() -> None: diff --git a/src/py/flwr/server/client_proxy.py b/src/py/flwr/server/client_proxy.py index 7d0547be304b..951e4ae992da 100644 --- a/src/py/flwr/server/client_proxy.py +++ b/src/py/flwr/server/client_proxy.py @@ -47,6 +47,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" @@ -55,6 +56,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" @@ -63,6 +65,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> FitRes: """Refine the provided parameters using the locally held dataset.""" @@ -71,6 +74,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" @@ -79,5 +83,6 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" diff --git a/src/py/flwr/server/client_proxy_test.py b/src/py/flwr/server/client_proxy_test.py index 266e4cbeb266..685698558e3a 100644 --- a/src/py/flwr/server/client_proxy_test.py +++ b/src/py/flwr/server/client_proxy_test.py @@ -42,6 +42,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" return GetPropertiesRes(status=Status(code=Code.OK, message=""), properties={}) @@ -50,6 +51,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" return GetParametersRes( @@ -61,6 +63,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> FitRes: """Refine the provided weights using the locally held dataset.""" return FitRes( @@ -74,6 +77,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided weights using the locally held dataset.""" return EvaluateRes( @@ -84,6 +88,7 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" return DisconnectRes(reason="") diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/server/compat/__init__.py similarity index 72% rename from src/py/flwr/flower/__init__.py rename to src/py/flwr/server/compat/__init__.py index 892a7ce5afdc..7bae196ddb65 100644 --- a/src/py/flwr/flower/__init__.py +++ b/src/py/flwr/server/compat/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Adap GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower callable package.""" +"""Flower ServerApp compatibility package.""" -from flwr.client.flower import Flower as Flower -from flwr.client.typing import Bwd as Bwd -from flwr.client.typing import Fwd as Fwd +from .app import start_driver as start_driver +from .legacy_context import LegacyContext as LegacyContext __all__ = [ - "Flower", - "Fwd", - "Bwd", + "LegacyContext", + "start_driver", ] diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/server/compat/app.py similarity index 52% rename from src/py/flwr/driver/app.py rename to src/py/flwr/server/compat/app.py index 4fa1ad8b5c02..c13a713b5f2c 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/server/compat/app.py @@ -16,24 +16,21 @@ import sys -import threading -import time from logging import INFO from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union from flwr.common import EventType, event from flwr.common.address import parse_address -from flwr.common.logger import log -from flwr.proto import driver_pb2 # pylint: disable=E0611 -from flwr.server.app import ServerConfig, init_defaults, run_fl +from flwr.common.logger import log, warn_deprecated_feature from flwr.server.client_manager import ClientManager from flwr.server.history import History -from flwr.server.server import Server +from flwr.server.server import Server, init_defaults, run_fl +from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy -from .driver_client_proxy import DriverClientProxy -from .grpc_driver import GrpcDriver +from ..driver import Driver +from .app_utils import start_update_client_manager_thread DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -53,6 +50,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, root_certificates: Optional[Union[bytes, str]] = None, + driver: Optional[Driver] = None, ) -> History: """Start a Flower Driver API server. @@ -72,15 +70,16 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals An implementation of the abstract base class `flwr.server.strategy.Strategy`. If no strategy is provided, then `start_server` will use `flwr.server.strategy.FedAvg`. - client_manager : Optional[flwr.driver.DriverClientManager] (default: None) - An implementation of the class - `flwr.driver.driver_client_manager.DriverClientManager`. If no + client_manager : Optional[flwr.server.ClientManager] (default: None) + An implementation of the class `flwr.server.ClientManager`. If no implementation is provided, then `start_driver` will use - `flwr.driver.driver_client_manager.DriverClientManager`. + `flwr.server.SimpleClientManager`. root_certificates : Optional[Union[bytes, str]] (default: None) The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. + driver : Optional[Driver] (default: None) + The Driver object to use. 
Returns ------- @@ -101,19 +100,23 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals """ event(EventType.START_DRIVER_ENTER) - # Parse IP address - parsed_address = parse_address(server_address) - if not parsed_address: - sys.exit(f"Server IP address ({server_address}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + if driver is None: + # Not passing a `Driver` object is deprecated + warn_deprecated_feature("start_driver") - # Create the Driver - if isinstance(root_certificates, str): - root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver(driver_service_address=address, certificates=root_certificates) - driver.connect() - lock = threading.Lock() + # Parse IP address + parsed_address = parse_address(server_address) + if not parsed_address: + sys.exit(f"Server IP address ({server_address}) cannot be parsed.") + host, port, is_v6 = parsed_address + address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + # Create the Driver + if isinstance(root_certificates, str): + root_certificates = Path(root_certificates).read_bytes() + driver = Driver( + driver_service_address=address, root_certificates=root_certificates + ) # Initialize the Driver API server and config initialized_server, initialized_config = init_defaults( @@ -124,20 +127,15 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals ) log( INFO, - "Starting Flower server, config: %s", + "Starting Flower ServerApp, config: %s", initialized_config, ) + log(INFO, "") # Start the thread updating nodes - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - initialized_server.client_manager(), - lock, - ), + thread, f_stop = start_update_client_manager_thread( + driver, initialized_server.client_manager() ) - thread.start() # Start training hist = run_fl( @@ -145,68 +143,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals config=initialized_config, ) + f_stop.set() + # Stop the Driver API server and the thread - with lock: - driver.disconnect() + del driver + thread.join() event(EventType.START_SERVER_LEAVE) return hist - - -def update_client_manager( - driver: GrpcDriver, - client_manager: ClientManager, - lock: threading.Lock, -) -> None: - """Update the nodes list in the client manager. - - This function periodically communicates with the associated driver to get all - node_ids. Each node_id is then converted into a `DriverClientProxy` instance - and stored in the `registered_nodes` dictionary with node_id as key. - - New nodes will be added to the ClientManager via `client_manager.register()`, - and dead nodes will be removed from the ClientManager via - `client_manager.unregister()`. 
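With this change, calling `start_driver` without a `Driver` is deprecated; a minimal sketch of the preferred call path (round count and strategy are example choices):

```python
# Sketch only: construct the Driver explicitly and pass it to start_driver.
from flwr.server import Driver, ServerConfig, start_driver
from flwr.server.strategy import FedAvg

driver = Driver(driver_service_address="[::]:9091")
history = start_driver(
    driver=driver,
    config=ServerConfig(num_rounds=3),
    strategy=FedAvg(),
)
print(history.losses_distributed)
```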
- """ - # Request for run_id - run_id = driver.create_run( - driver_pb2.CreateRunRequest() # pylint: disable=E1101 - ).run_id - - # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} - while True: - with lock: - # End the while loop if the driver is disconnected - if driver.stub is None: - break - get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101 - ) - all_node_ids = {node.node_id for node in get_nodes_res.nodes} - dead_nodes = set(registered_nodes).difference(all_node_ids) - new_nodes = all_node_ids.difference(registered_nodes) - - # Unregister dead nodes - for node_id in dead_nodes: - client_proxy = registered_nodes[node_id] - client_manager.unregister(client_proxy) - del registered_nodes[node_id] - - # Register new nodes - for node_id in new_nodes: - client_proxy = DriverClientProxy( - node_id=node_id, - driver=driver, - anonymous=False, - run_id=run_id, - ) - if client_manager.register(client_proxy): - registered_nodes[node_id] = client_proxy - else: - raise RuntimeError("Could not register node.") - - # Sleep for 3 seconds - time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py new file mode 100644 index 000000000000..696ec1132c4a --- /dev/null +++ b/src/py/flwr/server/compat/app_utils.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for the `start_driver`.""" + + +import threading +import time +from typing import Dict, Tuple + +from ..client_manager import ClientManager +from ..compat.driver_client_proxy import DriverClientProxy +from ..driver import Driver + + +def start_update_client_manager_thread( + driver: Driver, + client_manager: ClientManager, +) -> Tuple[threading.Thread, threading.Event]: + """Periodically update the nodes list in the client manager in a thread. + + This function starts a thread that periodically uses the associated driver to + get all node_ids. Each node_id is then converted into a `DriverClientProxy` + instance and stored in the `registered_nodes` dictionary with node_id as key. + + New nodes will be added to the ClientManager via `client_manager.register()`, + and dead nodes will be removed from the ClientManager via + `client_manager.unregister()`. + + Parameters + ---------- + driver : Driver + The Driver object to use. + client_manager : ClientManager + The ClientManager object to be updated. + + Returns + ------- + threading.Thread + A thread that updates the ClientManager and handles the stop event. + threading.Event + An event that, when set, signals the thread to stop. 
+ """ + f_stop = threading.Event() + thread = threading.Thread( + target=_update_client_manager, + args=( + driver, + client_manager, + f_stop, + ), + ) + thread.start() + + return thread, f_stop + + +def _update_client_manager( + driver: Driver, + client_manager: ClientManager, + f_stop: threading.Event, +) -> None: + """Update the nodes list in the client manager.""" + # Loop until the driver is disconnected + registered_nodes: Dict[int, DriverClientProxy] = {} + while not f_stop.is_set(): + all_node_ids = set(driver.get_node_ids()) + dead_nodes = set(registered_nodes).difference(all_node_ids) + new_nodes = all_node_ids.difference(registered_nodes) + + # Unregister dead nodes + for node_id in dead_nodes: + client_proxy = registered_nodes[node_id] + client_manager.unregister(client_proxy) + del registered_nodes[node_id] + + # Register new nodes + for node_id in new_nodes: + client_proxy = DriverClientProxy( + node_id=node_id, + driver=driver.grpc_driver, # type: ignore + anonymous=False, + run_id=driver.run_id, # type: ignore + ) + if client_manager.register(client_proxy): + registered_nodes[node_id] = client_proxy + else: + raise RuntimeError("Could not register node.") + + # Sleep for 3 seconds + time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py new file mode 100644 index 000000000000..7e47e6eaaf32 --- /dev/null +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for utility functions for the `start_driver`.""" + + +import time +import unittest +from unittest.mock import Mock, patch + +from ..client_manager import SimpleClientManager +from .app_utils import start_update_client_manager_thread + + +class TestUtils(unittest.TestCase): + """Tests for utility functions.""" + + def test_start_update_client_manager_thread(self) -> None: + """Test start_update_client_manager_thread function.""" + # Prepare + sleep = time.sleep + sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) + sleep_patch.start() + expected_node_ids = list(range(100)) + updated_expected_node_ids = list(range(80, 120)) + driver = Mock() + driver.grpc_driver = Mock() + driver.run_id = 123 + driver.get_node_ids.return_value = expected_node_ids + client_manager = SimpleClientManager() + + # Execute + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Update the GetNodesResponse and wait until the `client_manager` is updated + driver.get_node_ids.return_value = updated_expected_node_ids + sleep(0.1) + # Retrieve all nodes in `client_manager` + updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Stop the thread + f_stop.set() + + # Assert + assert node_ids == set(expected_node_ids) + assert updated_node_ids == set(updated_expected_node_ids) + + # Exit + thread.join() diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py similarity index 59% rename from src/py/flwr/driver/driver_client_proxy.py rename to src/py/flwr/server/compat/driver_client_proxy.py index 8b2e51c17ea0..84c67149fad7 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -16,19 +16,16 @@ import time -from typing import List, Optional, cast +from typing import List, Optional from flwr import common +from flwr.common import MessageType, MessageTypeLegacy, RecordSet +from flwr.common import recordset_compat as compat from flwr.common import serde -from flwr.proto import ( # pylint: disable=E0611 - driver_pb2, - node_pb2, - task_pb2, - transport_pb2, -) +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy -from .grpc_driver import GrpcDriver +from ..driver.grpc_driver import GrpcDriver SLEEP_TIME = 1 @@ -44,73 +41,82 @@ def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: in self.anonymous = anonymous def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] + self, + ins: common.GetPropertiesIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_properties_ins=ins) - ) - ) - return cast( - common.GetPropertiesRes, - self._send_receive_msg(server_message_proto, timeout).get_properties_res, + # Ins to RecordSet + out_recordset = compat.getpropertiesins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, 
group_id ) + # RecordSet to Res + return compat.recordset_to_getpropertiesres(in_recordset) def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] + self, + ins: common.GetParametersIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_parameters_ins=ins) - ) - ) - return cast( - common.GetParametersRes, - self._send_receive_msg(server_message_proto, timeout).get_parameters_res, + # Ins to RecordSet + out_recordset = compat.getparametersins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id ) + # RecordSet to Res + return compat.recordset_to_getparametersres(in_recordset, False) - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: + def fit( + self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.FitRes: """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(fit_ins=ins) - ) - ) - return cast( - common.FitRes, - self._send_receive_msg(server_message_proto, timeout).fit_res, + # Ins to RecordSet + out_recordset = compat.fitins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageType.TRAIN, timeout, group_id ) + # RecordSet to Res + return compat.recordset_to_fitres(in_recordset, keep_input=False) def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] + self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(evaluate_ins=ins) - ) - ) - return cast( - common.EvaluateRes, - self._send_receive_msg(server_message_proto, timeout).evaluate_res, + # Ins to RecordSet + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageType.EVALUATE, timeout, group_id ) + # RecordSet to Res + return compat.recordset_to_evaluateres(in_recordset) def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] + self, + ins: common.ReconnectIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - def _send_receive_msg( + def _send_receive_recordset( self, - server_message: transport_pb2.ServerMessage, # pylint: disable=E1101 + recordset: RecordSet, + task_type: str, timeout: Optional[float], - ) -> transport_pb2.ClientMessage: # pylint: disable=E1101 + group_id: Optional[int], + ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", - group_id="", + group_id=str(group_id) if group_id is not None else "", run_id=self.run_id, task=task_pb2.Task( # pylint: disable=E1101 producer=node_pb2.Node( # pylint: disable=E1101 @@ -121,7 +127,8 @@ def _send_receive_msg( node_id=self.node_id, anonymous=self.anonymous, ), - 
legacy_server_message=server_message, + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 @@ -155,9 +162,7 @@ def _send_receive_msg( ) if len(task_res_list) == 1: task_res = task_res_list[0] - return serde.client_message_from_proto( # type: ignore - task_res.task.legacy_client_message - ) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py similarity index 67% rename from src/py/flwr/driver/driver_client_proxy_test.py rename to src/py/flwr/server/compat/driver_client_proxy_test.py index d3cab152e4db..3494049c1064 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -16,23 +16,59 @@ import unittest +from typing import Union, cast from unittest.mock import MagicMock import numpy as np import flwr -from flwr.common.typing import Config, GetParametersIns -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.typing import ( + Code, + Config, + EvaluateIns, + EvaluateRes, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesRes, Parameters, - Scalar, + Properties, + Status, ) +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 + +from .driver_client_proxy import DriverClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") -CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} +CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) +CLIENT_STATUS = Status(code=Code.OK, message="OK") + + +def _make_task( + res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] +) -> task_pb2.Task: # pylint: disable=E1101 + if isinstance(res, GetParametersRes): + message_type = MessageTypeLegacy.GET_PARAMETERS + recordset = compat.getparametersres_to_recordset(res, True) + elif isinstance(res, GetPropertiesRes): + message_type = MessageTypeLegacy.GET_PROPERTIES + recordset = compat.getpropertiesres_to_recordset(res) + elif isinstance(res, FitRes): + message_type = MessageType.TRAIN + recordset = compat.fitres_to_recordset(res, True) + elif isinstance(res, EvaluateRes): + message_type = MessageType.EVALUATE + recordset = compat.evaluateres_to_recordset(res) + else: + raise ValueError(f"Unsupported type: {type(res)}") + return task_pb2.Task( # pylint: disable=E1101 + task_type=message_type, + recordset=serde.recordset_to_proto(recordset), + ) class DriverClientProxyTestCase(unittest.TestCase): @@ -62,13 +98,11 @@ def test_get_properties(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(0), run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES - ) + task=_make_task( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES ) ), ) @@ -84,7 +118,9 @@ def 
test_get_properties(self) -> None: ) # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) # Assert assert value.properties["tensor_type"] == "numpy.ndarray" @@ -102,13 +138,12 @@ def test_get_parameters(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(0), run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( - parameters=MESSAGE_PARAMETERS, - ) + task=_make_task( + GetParametersRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, ) ), ) @@ -122,7 +157,7 @@ def test_get_parameters(self) -> None: # Execute value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None + ins=get_parameters_ins, timeout=None, group_id=0 ) # Assert @@ -141,14 +176,14 @@ def test_fit(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(1), run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( - parameters=MESSAGE_PARAMETERS, - num_examples=10, - ) + task=_make_task( + FitRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + num_examples=10, + metrics={}, ) ), ) @@ -162,7 +197,7 @@ def test_fit(self) -> None: ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) # Execute - fit_res = client.fit(ins=ins, timeout=None) + fit_res = client.fit(ins=ins, timeout=None, group_id=1) # Assert assert fit_res.parameters.tensor_type == "np" @@ -182,13 +217,14 @@ def test_evaluate(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(1), run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 - ) + task=_make_task( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, ) ), ) @@ -198,11 +234,11 @@ def test_evaluate(self) -> None: client = DriverClientProxy( node_id=1, driver=self.driver, anonymous=True, run_id=0 ) - parameters = flwr.common.Parameters(tensors=[], tensor_type="np") - evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) + parameters = Parameters(tensors=[], tensor_type="np") + evaluate_ins = EvaluateIns(parameters, {}) # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) # Assert assert 0.0 == evaluate_res.loss diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py new file mode 100644 index 000000000000..0b00c98bb16d --- /dev/null +++ b/src/py/flwr/server/compat/legacy_context.py @@ -0,0 +1,55 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
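The `DriverClientProxy` methods above now shuttle the legacy `*Ins`/`*Res` types through `RecordSet` via `flwr.common.recordset_compat`; a minimal round-trip sketch using the converters that appear in this diff:

```python
# Sketch only: the RecordSet round trip performed by DriverClientProxy.fit.
from flwr.common import recordset_compat as compat
from flwr.common.typing import Code, FitIns, FitRes, Parameters, Status

params = Parameters(tensors=[b"abc"], tensor_type="np")

# Server side: pack FitIns into a RecordSet (what `fit` pushes)
out_recordset = compat.fitins_to_recordset(FitIns(params, {}), keep_input=True)

# The client would reply with a FitRes packed the same way
reply = compat.fitres_to_recordset(
    FitRes(Status(Code.OK, "OK"), params, num_examples=10, metrics={}),
    keep_input=True,
)

# Server side: unpack the reply (what `fit` returns)
fit_res = compat.recordset_to_fitres(reply, keep_input=False)
assert fit_res.num_examples == 10
```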
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Legacy Context.""" + + +from dataclasses import dataclass +from typing import Optional + +from flwr.common import Context, RecordSet + +from ..client_manager import ClientManager, SimpleClientManager +from ..history import History +from ..server_config import ServerConfig +from ..strategy import FedAvg, Strategy + + +@dataclass +class LegacyContext(Context): + """Legacy Context.""" + + config: ServerConfig + strategy: Strategy + client_manager: ClientManager + history: History + + def __init__( + self, + state: RecordSet, + config: Optional[ServerConfig] = None, + strategy: Optional[Strategy] = None, + client_manager: Optional[ClientManager] = None, + ) -> None: + if config is None: + config = ServerConfig() + if strategy is None: + strategy = FedAvg() + if client_manager is None: + client_manager = SimpleClientManager() + self.config = config + self.strategy = strategy + self.client_manager = client_manager + self.history = History() + super().__init__(state) diff --git a/src/py/flwr/server/criterion_test.py b/src/py/flwr/server/criterion_test.py index a7e5b62b5977..f678825f064e 100644 --- a/src/py/flwr/server/criterion_test.py +++ b/src/py/flwr/server/criterion_test.py @@ -20,7 +20,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.criterion import Criterion -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def test_criterion_applied() -> None: diff --git a/src/py/flwr/server/driver/__init__.py b/src/py/flwr/server/driver/__init__.py index 2bfe63e6065f..b61f6eebf6a8 100644 --- a/src/py/flwr/server/driver/__init__.py +++ b/src/py/flwr/server/driver/__init__.py @@ -12,4 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower driver service.""" +"""Flower driver SDK.""" + + +from .driver import Driver +from .grpc_driver import GrpcDriver + +__all__ = [ + "Driver", + "GrpcDriver", +] diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py new file mode 100644 index 000000000000..bcaac1f61b85 --- /dev/null +++ b/src/py/flwr/server/driver/driver.py @@ -0,0 +1,256 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
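A short sketch of the `LegacyContext` added above; fields left unset fall back to `ServerConfig()`, `FedAvg()`, and `SimpleClientManager()`:

```python
# Sketch only: build a LegacyContext and rely on its defaults.
from flwr.common import RecordSet
from flwr.server import LegacyContext, ServerConfig

context = LegacyContext(
    state=RecordSet(),
    config=ServerConfig(num_rounds=3),
)
print(type(context.strategy).__name__)        # -> FedAvg
print(type(context.client_manager).__name__)  # -> SimpleClientManager
```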
+# ============================================================================== +"""Flower driver service client.""" + + +import time +from typing import Iterable, List, Optional, Tuple + +from flwr.common import Message, Metadata, RecordSet +from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + GetNodesRequest, + PullTaskResRequest, + PushTaskInsRequest, +) +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + +from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver + + +class Driver: + """`Driver` class provides an interface to the Driver API. + + Parameters + ---------- + driver_service_address : Optional[str] + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. + root_certificates : Optional[bytes] (default: None) + The PEM-encoded root certificates as a byte string. If provided, a secure + connection using the certificates will be established to an SSL-enabled + Flower server. + """ + + def __init__( + self, + driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, + root_certificates: Optional[bytes] = None, + ) -> None: + self.addr = driver_service_address + self.root_certificates = root_certificates + self.grpc_driver: Optional[GrpcDriver] = None + self.run_id: Optional[int] = None + self.node = Node(node_id=0, anonymous=True) + + def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: + # Check if the GrpcDriver is initialized + if self.grpc_driver is None or self.run_id is None: + # Connect and create run + self.grpc_driver = GrpcDriver( + driver_service_address=self.addr, + root_certificates=self.root_certificates, + ) + self.grpc_driver.connect() + res = self.grpc_driver.create_run(CreateRunRequest()) + self.run_id = res.run_id + return self.grpc_driver, self.run_id + + def _check_message(self, message: Message) -> None: + # Check if the message is valid + if not ( + message.metadata.run_id == self.run_id + and message.metadata.src_node_id == self.node.node_id + and message.metadata.message_id == "" + and message.metadata.reply_to_message == "" + ): + raise ValueError(f"Invalid message: {message}") + + def create_message( # pylint: disable=too-many-arguments + self, + content: RecordSet, + message_type: str, + dst_node_id: int, + group_id: str, + ttl: str, + ) -> Message: + """Create a new message with specified parameters. + + This method constructs a new `Message` with given content and metadata. + The `run_id` and `src_node_id` will be set automatically. + + Parameters + ---------- + content : RecordSet + The content for the new message. This holds records that are to be sent + to the destination node. + message_type : str + The type of the message, defining the action to be executed on + the receiving end. + dst_node_id : int + The ID of the destination node to which the message is being sent. + group_id : str + The ID of the group to which this message is associated. In some settings, + this is used as the FL round. + ttl : str + Time-to-live for the round trip of this message, i.e., the time from sending + this message to receiving a reply. It specifies the duration for which the + message and its potential reply are considered valid. + + Returns + ------- + message : Message + A new `Message` instance with the specified content and metadata.
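A hedged example of the `create_message` call documented above; `driver` is an existing `Driver` and the destination node ID is a placeholder:

```python
# Sketch only: metadata fields run_id/src_node_id/message_id are set for you.
from flwr.common import MessageType, RecordSet

msg = driver.create_message(
    content=RecordSet(),
    message_type=MessageType.TRAIN,
    dst_node_id=123,   # placeholder node ID
    group_id="1",      # e.g. the current server round
    ttl="",            # TTL is passed as a string in this API version
)
assert msg.metadata.message_id == ""  # assigned later by the SuperLink
```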
+ """ + _, run_id = self._get_grpc_driver_and_run_id() + metadata = Metadata( + run_id=run_id, + message_id="", # Will be set by the server + src_node_id=self.node.node_id, + dst_node_id=dst_node_id, + reply_to_message="", + group_id=group_id, + ttl=ttl, + message_type=message_type, + ) + return Message(metadata=metadata, content=content) + + def get_node_ids(self) -> List[int]: + """Get node IDs.""" + grpc_driver, run_id = self._get_grpc_driver_and_run_id() + # Call GrpcDriver method + res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) + return [node.node_id for node in res.nodes] + + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + """Push messages to specified node IDs. + + This method takes an iterable of messages and sends each message + to the node specified in `dst_node_id`. + + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + + Returns + ------- + message_ids : Iterable[str] + An iterable of IDs for the messages that were sent, which can be used + to pull replies. + """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Construct TaskIns + task_ins_list: List[TaskIns] = [] + for msg in messages: + # Check message + self._check_message(msg) + # Convert Message to TaskIns + taskins = message_to_taskins(msg) + # Add to list + task_ins_list.append(taskins) + # Call GrpcDriver method + res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + return list(res.task_ids) + + def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: + """Pull messages based on message IDs. + + This method is used to collect messages from the SuperLink + that correspond to a set of given message IDs. + + Parameters + ---------- + message_ids : Iterable[str] + An iterable of message IDs for which reply messages are to be retrieved. + + Returns + ------- + messages : Iterable[Message] + An iterable of messages received. + """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Pull TaskRes + res = grpc_driver.pull_task_res( + PullTaskResRequest(node=self.node, task_ids=message_ids) + ) + # Convert TaskRes to Message + msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] + return msgs + + def send_and_receive( + self, + messages: Iterable[Message], + *, + timeout: Optional[float] = None, + ) -> Iterable[Message]: + """Push messages to specified node IDs and pull the reply messages. + + This method sends a list of messages to their destination node IDs and then + waits for the replies. It continues to pull replies until either all + replies are received or the specified timeout duration is exceeded. + + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the method will wait for + replies for this duration. If `None`, there is no time limit and the method + will wait until replies for all messages are received. + + Returns + ------- + replies : Iterable[Message] + An iterable of reply messages received from the SuperLink. + + Notes + ----- + This method uses `push_messages` to send the messages and `pull_messages` + to collect the replies. If `timeout` is set, the method may not return + replies for all sent messages. A message remains valid until its TTL, + which is not affected by `timeout`. 
+ """ + # Push messages + msg_ids = set(self.push_messages(messages)) + + # Pull messages + end_time = time.time() + (timeout if timeout is not None else 0.0) + ret: List[Message] = [] + while timeout is None or time.time() < end_time: + res_msgs = self.pull_messages(msg_ids) + ret.extend(res_msgs) + msg_ids.difference_update( + {msg.metadata.reply_to_message for msg in res_msgs} + ) + if len(msg_ids) == 0: + break + # Sleep + time.sleep(3) + return ret + + def __del__(self) -> None: + """Disconnect GrpcDriver if connected.""" + # Check if GrpcDriver is initialized + if self.grpc_driver is None: + return + # Disconnect + self.grpc_driver.disconnect() diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py similarity index 53% rename from src/py/flwr/driver/driver_test.py rename to src/py/flwr/server/driver/driver_test.py index 1854a92b5ebe..2bf253222f94 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -15,16 +15,21 @@ """Tests for driver SDK.""" +import time import unittest from unittest.mock import Mock, patch -from flwr.driver.driver import Driver +from flwr.common import RecordSet +from flwr.common.message import Error +from flwr.common.serde import error_to_proto, recordset_to_proto from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, PullTaskResRequest, PushTaskInsRequest, ) -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 + +from .driver import Driver class TestDriver(unittest.TestCase): @@ -37,7 +42,7 @@ def setUp(self) -> None: self.mock_grpc_driver = Mock() self.mock_grpc_driver.create_run.return_value = mock_response self.patcher = patch( - "flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver + "flwr.server.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver ) self.patcher.start() self.driver = Driver() @@ -73,11 +78,11 @@ def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() - mock_response.nodes = [Mock(), Mock()] + mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] self.mock_grpc_driver.get_nodes.return_value = mock_response # Execute - nodes = self.driver.get_nodes() + node_ids = self.driver.get_node_ids() args, kwargs = self.mock_grpc_driver.get_nodes.call_args # Assert @@ -86,18 +91,19 @@ def test_get_nodes(self) -> None: self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) self.assertEqual(args[0].run_id, 61016) - self.assertEqual(nodes, mock_response.nodes) + self.assertEqual(node_ids, [404, 200]) - def test_push_task_ins(self) -> None: - """Test pushing task instructions.""" + def test_push_messages_valid(self) -> None: + """Test pushing valid messages.""" # Prepare - mock_response = Mock() - mock_response.task_ids = ["id1", "id2"] + mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response - task_ins_list = [TaskIns(), TaskIns()] + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] # Execute - task_ids = self.driver.push_task_ins(task_ins_list) + msg_ids = self.driver.push_messages(msgs) args, kwargs = self.mock_grpc_driver.push_task_ins.call_args # Assert @@ -105,23 +111,43 @@ def test_push_task_ins(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) - self.assertEqual(task_ids, 
mock_response.task_ids) + self.assertEqual(msg_ids, mock_response.task_ids) for task_ins in args[0].task_ins_list: self.assertEqual(task_ins.run_id, 61016) - def test_pull_task_res_with_given_task_ids(self) -> None: - """Test pulling task results with specific task IDs.""" + def test_push_messages_invalid(self) -> None: + """Test pushing invalid messages.""" + # Prepare + mock_response = Mock(task_ids=["id1", "id2"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] + # Use invalid run_id + msgs[1].metadata._run_id += 1 # pylint: disable=protected-access + + # Execute and assert + with self.assertRaises(ValueError): + self.driver.push_messages(msgs) + + def test_pull_messages_with_given_message_ids(self) -> None: + """Test pulling messages with specific message IDs.""" # Prepare mock_response = Mock() + # A Message must have either content or error set so we prepare + # two tasks that contain these. mock_response.task_res_list = [ - TaskRes(task=Task(ancestry=["id2"])), - TaskRes(task=Task(ancestry=["id3"])), + TaskRes( + task=Task(ancestry=["id2"], recordset=recordset_to_proto(RecordSet())) + ), + TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] self.mock_grpc_driver.pull_task_res.return_value = mock_response - task_ids = ["id1", "id2", "id3"] + msg_ids = ["id1", "id2", "id3"] # Execute - task_res_list = self.driver.pull_task_res(task_ids) + msgs = self.driver.pull_messages(msg_ids) + reply_tos = {msg.metadata.reply_to_message for msg in msgs} args, kwargs = self.mock_grpc_driver.pull_task_res.call_args # Assert @@ -129,8 +155,48 @@ def test_pull_task_res_with_given_task_ids(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) - self.assertEqual(args[0].task_ids, task_ids) - self.assertEqual(task_res_list, mock_response.task_res_list) + self.assertEqual(args[0].task_ids, msg_ids) + self.assertEqual(reply_tos, {"id2", "id3"}) + + def test_send_and_receive_messages_complete(self) -> None: + """Test send and receive all messages successfully.""" + # Prepare + mock_response = Mock(task_ids=["id1"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + # The response message must include either `content` (i.e. a recordset) or + # an `Error`. 
+        error_proto = error_to_proto(Error(code=0))
+        mock_response = Mock(
+            task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))]
+        )
+        self.mock_grpc_driver.pull_task_res.return_value = mock_response
+        msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")]
+
+        # Execute
+        ret_msgs = list(self.driver.send_and_receive(msgs))
+
+        # Assert
+        self.assertEqual(len(ret_msgs), 1)
+        self.assertEqual(ret_msgs[0].metadata.reply_to_message, "id1")
+
+    def test_send_and_receive_messages_timeout(self) -> None:
+        """Test send and receive messages but time out."""
+        # Prepare
+        sleep_fn = time.sleep
+        mock_response = Mock(task_ids=["id1"])
+        self.mock_grpc_driver.push_task_ins.return_value = mock_response
+        mock_response = Mock(task_res_list=[])
+        self.mock_grpc_driver.pull_task_res.return_value = mock_response
+        msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")]
+
+        # Execute
+        with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)):
+            start_time = time.time()
+            ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15))
+
+        # Assert
+        self.assertLess(time.time() - start_time, 0.2)
+        self.assertEqual(len(ret_msgs), 0)
 
     def test_del_with_initialized_driver(self) -> None:
         """Test cleanup behavior when Driver is initialized."""
diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py
similarity index 90%
rename from src/py/flwr/driver/grpc_driver.py
rename to src/py/flwr/server/driver/grpc_driver.py
index 23d449790092..b6e2b2602cd5 100644
--- a/src/py/flwr/driver/grpc_driver.py
+++ b/src/py/flwr/server/driver/grpc_driver.py
@@ -15,7 +15,7 @@
 """Flower driver service client."""
 
 
-from logging import ERROR, INFO, WARNING
+from logging import DEBUG, ERROR, WARNING
 from typing import Optional
 
 import grpc
@@ -51,10 +51,10 @@ class GrpcDriver:
     def __init__(
         self,
         driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
-        certificates: Optional[bytes] = None,
+        root_certificates: Optional[bytes] = None,
    ) -> None:
         self.driver_service_address = driver_service_address
-        self.certificates = certificates
+        self.root_certificates = root_certificates
         self.channel: Optional[grpc.Channel] = None
         self.stub: Optional[DriverStub] = None
 
@@ -66,23 +66,23 @@ def connect(self) -> None:
             return
         self.channel = create_channel(
             server_address=self.driver_service_address,
-            insecure=(self.certificates is None),
-            root_certificates=self.certificates,
+            insecure=(self.root_certificates is None),
+            root_certificates=self.root_certificates,
         )
         self.stub = DriverStub(self.channel)
-        log(INFO, "[Driver] Connected to %s", self.driver_service_address)
+        log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
 
     def disconnect(self) -> None:
         """Disconnect from the Driver API."""
         event(EventType.DRIVER_DISCONNECT)
         if self.channel is None or self.stub is None:
-            log(WARNING, "Already disconnected")
+            log(DEBUG, "Already disconnected")
             return
         channel = self.channel
         self.channel = None
         self.stub = None
         channel.close()
-        log(INFO, "[Driver] Disconnected")
+        log(DEBUG, "[Driver] Disconnected")
 
     def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
         """Request for run ID."""
diff --git a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py b/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py
deleted file mode 100644
index dc94bf3912d7..000000000000
--- a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Flower DriverClientManager."""
-
-
-import threading
-from typing import Dict, List, Optional, Set, Tuple
-
-from flwr.server.client_manager import ClientManager
-from flwr.server.client_proxy import ClientProxy
-from flwr.server.criterion import Criterion
-from flwr.server.state import State, StateFactory
-
-from .ins_scheduler import InsScheduler
-
-
-class DriverClientManager(ClientManager):
-    """Provides a pool of available clients."""
-
-    def __init__(self, state_factory: StateFactory) -> None:
-        self._cv = threading.Condition()
-        self.nodes: Dict[str, Tuple[int, InsScheduler]] = {}
-        self.state_factory = state_factory
-
-    def __len__(self) -> int:
-        """Return the number of available clients.
-
-        Returns
-        -------
-        num_available : int
-            The number of currently available clients.
-        """
-        return len(self.nodes)
-
-    def num_available(self) -> int:
-        """Return the number of available clients.
-
-        Returns
-        -------
-        num_available : int
-            The number of currently available clients.
-        """
-        return len(self)
-
-    def register(self, client: ClientProxy) -> bool:
-        """Register Flower ClientProxy instance.
-
-        Parameters
-        ----------
-        client : flwr.server.client_proxy.ClientProxy
-
-        Returns
-        -------
-        success : bool
-            Indicating if registration was successful. False if ClientProxy is
-            already registered or can not be registered for any reason.
-        """
-        if client.cid in self.nodes:
-            return False
-
-        # Create node in State
-        state: State = self.state_factory.state()
-        client.node_id = state.create_node()
-
-        # Create and start the instruction scheduler
-        ins_scheduler = InsScheduler(
-            client_proxy=client,
-            state_factory=self.state_factory,
-        )
-        ins_scheduler.start()
-
-        # Store cid, node_id, and InsScheduler
-        self.nodes[client.cid] = (client.node_id, ins_scheduler)
-
-        with self._cv:
-            self._cv.notify_all()
-
-        return True
-
-    def unregister(self, client: ClientProxy) -> None:
-        """Unregister Flower ClientProxy instance.
-
-        This method is idempotent.
-
-        Parameters
-        ----------
-        client : flwr.server.client_proxy.ClientProxy
-        """
-        if client.cid in self.nodes:
-            node_id, ins_scheduler = self.nodes[client.cid]
-            del self.nodes[client.cid]
-            ins_scheduler.stop()
-
-            # Delete node_id in State
-            state: State = self.state_factory.state()
-            state.delete_node(node_id=node_id)
-
-            with self._cv:
-                self._cv.notify_all()
-
-    def all_ids(self) -> Set[int]:
-        """Return all available node ids.
-
-        Returns
-        -------
-        ids : Set[int]
-            The IDs of all currently available nodes.
- """ - return {node_id for _, (node_id, _) in self.nodes.items()} - - # --- Unimplemented methods ----------------------------------------------- - - def all(self) -> Dict[str, ClientProxy]: - """Not implemented.""" - raise NotImplementedError() - - def wait_for(self, num_clients: int, timeout: int = 86400) -> bool: - """Not implemented.""" - raise NotImplementedError() - - def sample( - self, - num_clients: int, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: - """Not implemented.""" - raise NotImplementedError() diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py deleted file mode 100644 index 5843934b64a4..000000000000 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Instruction scheduler for the legacy gRPC transport stack.""" - - -import threading -import time -from logging import DEBUG, ERROR -from typing import Dict, List, Optional - -from flwr.client.message_handler.task_handler import configure_task_res -from flwr.common import EvaluateRes, FitRes, GetParametersRes, GetPropertiesRes, serde -from flwr.common.logger import log -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) -from flwr.server.client_proxy import ClientProxy -from flwr.server.state import State, StateFactory - - -class InsScheduler: - """Schedule ClientProxy calls on a background thread.""" - - def __init__(self, client_proxy: ClientProxy, state_factory: StateFactory): - self.client_proxy = client_proxy - self.state_factory = state_factory - self.worker_thread: Optional[threading.Thread] = None - self.shared_memory_state = {"stop": False} - - def start(self) -> None: - """Start the worker thread.""" - self.worker_thread = threading.Thread( - target=_worker, - args=( - self.client_proxy, - self.shared_memory_state, - self.state_factory, - ), - ) - self.worker_thread.start() - - def stop(self) -> None: - """Stop the worker thread.""" - if self.worker_thread is None: - log(ERROR, "InsScheduler.stop called, but worker_thread is None") - return - self.shared_memory_state["stop"] = True - self.worker_thread.join() - self.worker_thread = None - self.shared_memory_state["stop"] = False - - -def _worker( - client_proxy: ClientProxy, - shared_memory_state: Dict[str, bool], - state_factory: StateFactory, -) -> None: - """Sequentially call ClientProxy methods to process outstanding tasks.""" - log(DEBUG, "Worker for node %i started", client_proxy.node_id) - - state: State = state_factory.state() - while not shared_memory_state["stop"]: - log(DEBUG, "Worker for node %i checking state", 
-
-        # Step 1: pull *Ins (next task) out of `state`
-        task_ins_list: List[TaskIns] = state.get_task_ins(
-            node_id=client_proxy.node_id,
-            limit=1,
-        )
-        if not task_ins_list:
-            log(DEBUG, "Worker for node %i: no task found", client_proxy.node_id)
-            time.sleep(3)
-            continue
-
-        task_ins = task_ins_list[0]
-        log(
-            DEBUG,
-            "Worker for node %i: FOUND task %s",
-            client_proxy.node_id,
-            task_ins.task_id,
-        )
-
-        # Step 2: call client_proxy.{fit,evaluate,...}
-        server_message = task_ins.task.legacy_server_message
-        client_message_proto = _call_client_proxy(
-            client_proxy=client_proxy,
-            server_message=server_message,
-            timeout=None,
-        )
-
-        # Step 3: wrap FitRes in a ClientMessage in a Task in a TaskRes
-        task_res = configure_task_res(
-            TaskRes(task=Task(legacy_client_message=client_message_proto)),
-            task_ins,
-            Node(node_id=client_proxy.node_id, anonymous=False),
-        )
-
-        # Step 4: write *Res (result) back to `state`
-        state.store_task_res(task_res=task_res)
-
-    # Exit worker thread
-    log(DEBUG, "Worker for node %i stopped", client_proxy.node_id)
-
-
-def _call_client_proxy(
-    client_proxy: ClientProxy, server_message: ServerMessage, timeout: Optional[float]
-) -> ClientMessage:
-    """."""
-    # pylint: disable=too-many-locals
-
-    field = server_message.WhichOneof("msg")
-
-    if field == "get_properties_ins":
-        get_properties_ins = serde.get_properties_ins_from_proto(
-            msg=server_message.get_properties_ins
-        )
-        get_properties_res: GetPropertiesRes = client_proxy.get_properties(
-            ins=get_properties_ins,
-            timeout=timeout,
-        )
-        get_properties_res_proto = serde.get_properties_res_to_proto(
-            res=get_properties_res
-        )
-        return ClientMessage(get_properties_res=get_properties_res_proto)
-
-    if field == "get_parameters_ins":
-        get_parameters_ins = serde.get_parameters_ins_from_proto(
-            msg=server_message.get_parameters_ins
-        )
-        get_parameters_res: GetParametersRes = client_proxy.get_parameters(
-            ins=get_parameters_ins,
-            timeout=timeout,
-        )
-        get_parameters_res_proto = serde.get_parameters_res_to_proto(
-            res=get_parameters_res
-        )
-        return ClientMessage(get_parameters_res=get_parameters_res_proto)
-
-    if field == "fit_ins":
-        fit_ins = serde.fit_ins_from_proto(msg=server_message.fit_ins)
-        fit_res: FitRes = client_proxy.fit(
-            ins=fit_ins,
-            timeout=timeout,
-        )
-        fit_res_proto = serde.fit_res_to_proto(res=fit_res)
-        return ClientMessage(fit_res=fit_res_proto)
-
-    if field == "evaluate_ins":
-        evaluate_ins = serde.evaluate_ins_from_proto(msg=server_message.evaluate_ins)
-        evaluate_res: EvaluateRes = client_proxy.evaluate(
-            ins=evaluate_ins,
-            timeout=timeout,
-        )
-        evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res)
-        return ClientMessage(evaluate_res=evaluate_res_proto)
-
-    raise ValueError(
-        "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf"
-    )
diff --git a/src/py/flwr/server/history.py b/src/py/flwr/server/history.py
index ad5f5d0fc870..c4298911d97b 100644
--- a/src/py/flwr/server/history.py
+++ b/src/py/flwr/server/history.py
@@ -15,6 +15,7 @@
 """Training history."""
 
 
+import pprint
 from functools import reduce
 from typing import Dict, List, Tuple
 
@@ -90,29 +91,35 @@ def __repr__(self) -> str:
         """
         rep = ""
         if self.losses_distributed:
-            rep += "History (loss, distributed):\n" + reduce(
-                lambda a, b: a + b,
-                [
-                    f"\tround {server_round}: {loss}\n"
-                    for server_round, loss in self.losses_distributed
-                ],
+            rep += "History (loss, distributed):\n" + pprint.pformat(
+                reduce(
+                    lambda a, b: a + b,
+                    [
+                        f"\tround {server_round}: {loss}\n"
+                        for server_round, loss in self.losses_distributed
+                    ],
+                )
             )
         if self.losses_centralized:
-            rep += "History (loss, centralized):\n" + reduce(
-                lambda a, b: a + b,
-                [
-                    f"\tround {server_round}: {loss}\n"
-                    for server_round, loss in self.losses_centralized
-                ],
+            rep += "History (loss, centralized):\n" + pprint.pformat(
+                reduce(
+                    lambda a, b: a + b,
+                    [
+                        f"\tround {server_round}: {loss}\n"
+                        for server_round, loss in self.losses_centralized
+                    ],
+                )
             )
         if self.metrics_distributed_fit:
-            rep += "History (metrics, distributed, fit):\n" + str(
+            rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
                 self.metrics_distributed_fit
             )
         if self.metrics_distributed:
-            rep += "History (metrics, distributed, evaluate):\n" + str(
+            rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
                 self.metrics_distributed
             )
         if self.metrics_centralized:
-            rep += "History (metrics, centralized):\n" + str(self.metrics_centralized)
+            rep += "History (metrics, centralized):\n" + pprint.pformat(
+                self.metrics_centralized
+            )
         return rep
diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py
new file mode 100644
index 000000000000..5b00d356886a
--- /dev/null
+++ b/src/py/flwr/server/run_serverapp.py
@@ -0,0 +1,187 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run ServerApp."""
+
+
+import argparse
+import sys
+from logging import DEBUG, INFO, WARN
+from pathlib import Path
+from typing import Optional
+
+from flwr.common import Context, EventType, RecordSet, event
+from flwr.common.logger import log, update_console_handler
+from flwr.common.object_ref import load_app
+
+from .driver.driver import Driver
+from .server_app import LoadServerAppError, ServerApp
+
+
+def run(
+    driver: Driver,
+    server_app_dir: str,
+    server_app_attr: Optional[str] = None,
+    loaded_server_app: Optional[ServerApp] = None,
+) -> None:
+    """Run ServerApp with a given Driver."""
+    if not (server_app_attr is None) ^ (loaded_server_app is None):
+        raise ValueError(
+            "Either `server_app_attr` or `loaded_server_app` should be set "
+            "but not both. "
" + ) + + if server_app_dir is not None: + sys.path.insert(0, server_app_dir) + + # Load ServerApp if needed + def _load() -> ServerApp: + if server_app_attr: + server_app: ServerApp = load_app(server_app_attr, LoadServerAppError) + + if not isinstance(server_app, ServerApp): + raise LoadServerAppError( + f"Attribute {server_app_attr} is not of type {ServerApp}", + ) from None + + if loaded_server_app: + server_app = loaded_server_app + return server_app + + server_app = _load() + + # Initialize Context + context = Context(state=RecordSet()) + + # Call ServerApp + server_app(driver=driver, context=context) + + log(DEBUG, "ServerApp finished running.") + + +def run_server_app() -> None: + """Run Flower server app.""" + event(EventType.RUN_SERVER_APP_ENTER) + + args = _parse_args_run_server_app().parse_args() + + update_console_handler( + level=DEBUG if args.verbose else INFO, + timestamps=args.verbose, + colored=True, + ) + + # Obtain certificates + if args.insecure: + if args.root_certificates is not None: + sys.exit( + "Conflicting options: The '--insecure' flag disables HTTPS, " + "but '--root-certificates' was also specified. Please remove " + "the '--root-certificates' option when running in insecure mode, " + "or omit '--insecure' to use HTTPS." + ) + log( + WARN, + "Option `--insecure` was set. " + "Starting insecure HTTP client connected to %s.", + args.server, + ) + root_certificates = None + else: + # Load the certificates if provided, or load the system certificates + cert_path = args.root_certificates + if cert_path is None: + root_certificates = None + else: + root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) + + log( + DEBUG, + "Flower will load ServerApp `%s`", + getattr(args, "server-app"), + ) + + log( + DEBUG, + "root_certificates: `%s`", + root_certificates, + ) + + server_app_dir = args.dir + server_app_attr = getattr(args, "server-app") + + # Initialize Driver + driver = Driver( + driver_service_address=args.server, + root_certificates=root_certificates, + ) + + # Run the Server App with the Driver + run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) + + # Clean up + driver.__del__() # pylint: disable=unnecessary-dunder-call + + event(EventType.RUN_SERVER_APP_LEAVE) + + +def _parse_args_run_server_app() -> argparse.ArgumentParser: + """Parse flower-server-app command line arguments.""" + parser = argparse.ArgumentParser( + description="Start a Flower server app", + ) + + parser.add_argument( + "server-app", + help="For example: `server:app` or `project.package.module:wrapper.app`", + ) + parser.add_argument( + "--insecure", + action="store_true", + help="Run the server app without HTTPS. By default, the app runs with " + "HTTPS enabled. Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Set the logging to `DEBUG`.", + ) + parser.add_argument( + "--root-certificates", + metavar="ROOT_CERT", + type=str, + help="Specifies the path to the PEM-encoded root certificate file for " + "establishing secure HTTPS connections.", + ) + parser.add_argument( + "--server", + default="0.0.0.0:9091", + help="Server address", + ) + parser.add_argument( + "--dir", + default="", + help="Add specified directory to the PYTHONPATH and load Flower " + "app from there." 
+ " Default: current working directory.", + ) + + return parser diff --git a/src/py/flwr/server/server.py b/src/py/flwr/server/server.py index cf3a4d9aa07c..981325a6df08 100644 --- a/src/py/flwr/server/server.py +++ b/src/py/flwr/server/server.py @@ -16,8 +16,9 @@ import concurrent.futures +import io import timeit -from logging import DEBUG, INFO +from logging import INFO, WARN from typing import Dict, List, Optional, Tuple, Union from flwr.common import ( @@ -33,11 +34,13 @@ ) from flwr.common.logger import log from flwr.common.typing import GetParametersIns -from flwr.server.client_manager import ClientManager +from flwr.server.client_manager import ClientManager, SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.history import History from flwr.server.strategy import FedAvg, Strategy +from .server_config import ServerConfig + FitResultsAndFailures = Tuple[ List[Tuple[ClientProxy, FitRes]], List[Union[Tuple[ClientProxy, FitRes], BaseException]], @@ -81,14 +84,14 @@ def client_manager(self) -> ClientManager: return self._client_manager # pylint: disable=too-many-locals - def fit(self, num_rounds: int, timeout: Optional[float]) -> History: + def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: """Run federated averaging for a number of rounds.""" history = History() # Initialize parameters - log(INFO, "Initializing global parameters") - self.parameters = self._get_initial_parameters(timeout=timeout) - log(INFO, "Evaluating initial parameters") + log(INFO, "[INIT]") + self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout) + log(INFO, "Evaluating initial global parameters") res = self.strategy.evaluate(0, parameters=self.parameters) if res is not None: log( @@ -101,10 +104,11 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: history.add_metrics_centralized(server_round=0, metrics=res[1]) # Run federated learning for num_rounds - log(INFO, "FL starting") start_time = timeit.default_timer() for current_round in range(1, num_rounds + 1): + log(INFO, "") + log(INFO, "[ROUND %s]", current_round) # Train model and replace previous global model res_fit = self.fit_round( server_round=current_round, @@ -150,8 +154,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: # Bookkeeping end_time = timeit.default_timer() elapsed = end_time - start_time - log(INFO, "FL finished in %s", elapsed) - return history + return history, elapsed def evaluate_round( self, @@ -168,12 +171,11 @@ def evaluate_round( client_manager=self._client_manager, ) if not client_instructions: - log(INFO, "evaluate_round %s: no clients selected, cancel", server_round) + log(INFO, "configure_evaluate: no clients selected, skipping evaluation") return None log( - DEBUG, - "evaluate_round %s: strategy sampled %s clients (out of %s)", - server_round, + INFO, + "configure_evaluate: strategy sampled %s clients (out of %s)", len(client_instructions), self._client_manager.num_available(), ) @@ -183,11 +185,11 @@ def evaluate_round( client_instructions, max_workers=self.max_workers, timeout=timeout, + group_id=server_round, ) log( - DEBUG, - "evaluate_round %s received %s results and %s failures", - server_round, + INFO, + "aggregate_evaluate: received %s results and %s failures", len(results), len(failures), ) @@ -217,12 +219,11 @@ def fit_round( ) if not client_instructions: - log(INFO, "fit_round %s: no clients selected, cancel", server_round) + log(INFO, "configure_fit: no clients selected, cancel") return 
         log(
-            DEBUG,
-            "fit_round %s: strategy sampled %s clients (out of %s)",
-            server_round,
+            INFO,
+            "configure_fit: strategy sampled %s clients (out of %s)",
             len(client_instructions),
             self._client_manager.num_available(),
         )
@@ -232,11 +233,11 @@ def fit_round(
             client_instructions=client_instructions,
             max_workers=self.max_workers,
             timeout=timeout,
+            group_id=server_round,
         )
         log(
-            DEBUG,
-            "fit_round %s received %s results and %s failures",
-            server_round,
+            INFO,
+            "aggregate_fit: received %s results and %s failures",
             len(results),
             len(failures),
         )
@@ -262,21 +263,25 @@ def disconnect_all_clients(self, timeout: Optional[float]) -> None:
             timeout=timeout,
         )
 
-    def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters:
+    def _get_initial_parameters(
+        self, server_round: int, timeout: Optional[float]
+    ) -> Parameters:
         """Get initial parameters from one of the available clients."""
         # Server-side parameter initialization
         parameters: Optional[Parameters] = self.strategy.initialize_parameters(
             client_manager=self._client_manager
         )
         if parameters is not None:
-            log(INFO, "Using initial parameters provided by strategy")
+            log(INFO, "Using initial global parameters provided by strategy")
             return parameters
 
         # Get initial parameters from one of the clients
         log(INFO, "Requesting initial parameters from one random client")
         random_client = self._client_manager.sample(1)[0]
         ins = GetParametersIns(config={})
-        get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout)
+        get_parameters_res = random_client.get_parameters(
+            ins=ins, timeout=timeout, group_id=server_round
+        )
         log(INFO, "Received initial parameters from one random client")
 
         return get_parameters_res.parameters
@@ -319,6 +324,7 @@ def reconnect_client(
     disconnect = client.reconnect(
         reconnect,
         timeout=timeout,
+        group_id=None,
    )
    return client, disconnect
 
@@ -327,11 +333,12 @@ def fit_clients(
     client_instructions: List[Tuple[ClientProxy, FitIns]],
     max_workers: Optional[int],
     timeout: Optional[float],
+    group_id: int,
 ) -> FitResultsAndFailures:
     """Refine parameters concurrently on all selected clients."""
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         submitted_fs = {
-            executor.submit(fit_client, client_proxy, ins, timeout)
+            executor.submit(fit_client, client_proxy, ins, timeout, group_id)
             for client_proxy, ins in client_instructions
         }
         finished_fs, _ = concurrent.futures.wait(
@@ -350,10 +357,10 @@ def fit_clients(
 
 
 def fit_client(
-    client: ClientProxy, ins: FitIns, timeout: Optional[float]
+    client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int
 ) -> Tuple[ClientProxy, FitRes]:
     """Refine parameters on a single client."""
-    fit_res = client.fit(ins, timeout=timeout)
+    fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
     return client, fit_res
 
 
@@ -386,11 +393,12 @@ def evaluate_clients(
     client_instructions: List[Tuple[ClientProxy, EvaluateIns]],
     max_workers: Optional[int],
     timeout: Optional[float],
+    group_id: int,
 ) -> EvaluateResultsAndFailures:
     """Evaluate parameters concurrently on all selected clients."""
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         submitted_fs = {
-            executor.submit(evaluate_client, client_proxy, ins, timeout)
+            executor.submit(evaluate_client, client_proxy, ins, timeout, group_id)
             for client_proxy, ins in client_instructions
         }
         finished_fs, _ = concurrent.futures.wait(
@@ -412,9 +420,10 @@ def evaluate_client(
     client: ClientProxy,
     ins: EvaluateIns,
     timeout: Optional[float],
+    group_id: int,
 ) -> Tuple[ClientProxy, EvaluateRes]:
     """Evaluate parameters on a single client."""
-    evaluate_res = client.evaluate(ins, timeout=timeout)
+    evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id)
     return client, evaluate_res
 
 
@@ -441,3 +450,51 @@ def _handle_finished_future_after_evaluate(
 
     # Not successful, client returned a result where the status code is not OK
     failures.append(result)
+
+
+def init_defaults(
+    server: Optional[Server],
+    config: Optional[ServerConfig],
+    strategy: Optional[Strategy],
+    client_manager: Optional[ClientManager],
+) -> Tuple[Server, ServerConfig]:
+    """Create server instance if none was given."""
+    if server is None:
+        if client_manager is None:
+            client_manager = SimpleClientManager()
+        if strategy is None:
+            strategy = FedAvg()
+        server = Server(client_manager=client_manager, strategy=strategy)
+    elif strategy is not None:
+        log(WARN, "Both server and strategy were provided, ignoring strategy")
+
+    # Set default config values
+    if config is None:
+        config = ServerConfig()
+
+    return server, config
+
+
+def run_fl(
+    server: Server,
+    config: ServerConfig,
+) -> History:
+    """Train a model on the given server and return the History object."""
+    hist, elapsed_time = server.fit(
+        num_rounds=config.num_rounds, timeout=config.round_timeout
+    )
+
+    log(INFO, "")
+    log(INFO, "[SUMMARY]")
+    log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
+    for idx, line in enumerate(io.StringIO(str(hist))):
+        if idx == 0:
+            log(INFO, "%s", line.strip("\n"))
+        else:
+            log(INFO, "\t%s", line.strip("\n"))
+    log(INFO, "")
+
+    # Graceful shutdown
+    server.disconnect_all_clients(timeout=config.round_timeout)
+
+    return hist
diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py
new file mode 100644
index 000000000000..1b2eab87fdaa
--- /dev/null
+++ b/src/py/flwr/server/server_app.py
@@ -0,0 +1,133 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower ServerApp."""
+
+
+from typing import Callable, Optional
+
+from flwr.common import Context, RecordSet
+from flwr.server.strategy import Strategy
+
+from .client_manager import ClientManager
+from .compat import start_driver
+from .driver import Driver
+from .server import Server
+from .server_config import ServerConfig
+from .typing import ServerAppCallable
+
+
+class ServerApp:
+    """Flower ServerApp.
+
+    Examples
+    --------
+    Use the `ServerApp` with an existing `Strategy`:
+
+    >>> server_config = ServerConfig(num_rounds=3)
+    >>> strategy = FedAvg()
+    >>>
+    >>> app = ServerApp(
+    >>>     config=server_config,
+    >>>     strategy=strategy,
+    >>> )
+
+    Use the `ServerApp` with a custom main function:
+
+    >>> app = ServerApp()
+    >>>
+    >>> @app.main()
+    >>> def main(driver: Driver, context: Context) -> None:
+    >>>    print("ServerApp running")
+    """
+
+    def __init__(
+        self,
+        server: Optional[Server] = None,
+        config: Optional[ServerConfig] = None,
+        strategy: Optional[Strategy] = None,
+        client_manager: Optional[ClientManager] = None,
+    ) -> None:
+        self._server = server
+        self._config = config
+        self._strategy = strategy
+        self._client_manager = client_manager
+        self._main: Optional[ServerAppCallable] = None
+
+    def __call__(self, driver: Driver, context: Context) -> None:
+        """Execute `ServerApp`."""
+        # Compatibility mode
+        if not self._main:
+            start_driver(
+                server=self._server,
+                config=self._config,
+                strategy=self._strategy,
+                client_manager=self._client_manager,
+                driver=driver,
+            )
+            return
+
+        # New execution mode
+        context = Context(state=RecordSet())
+        self._main(driver, context)
+
+    def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
+        """Return a decorator that registers the main fn with the server app.
+
+        Examples
+        --------
+        >>> app = ServerApp()
+        >>>
+        >>> @app.main()
+        >>> def main(driver: Driver, context: Context) -> None:
+        >>>    print("ServerApp running")
+        """
+
+        def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable:
+            """Register the main fn with the ServerApp object."""
+            if self._server or self._config or self._strategy or self._client_manager:
+                raise ValueError(
+                    """Use either a custom main function or a `Strategy`, but not both.
+
+                    Use the `ServerApp` with an existing `Strategy`:
+
+                    >>> server_config = ServerConfig(num_rounds=3)
+                    >>> strategy = FedAvg()
+                    >>>
+                    >>> app = ServerApp(
+                    >>>     config=server_config,
+                    >>>     strategy=strategy,
+                    >>> )
+
+                    Use the `ServerApp` with a custom main function:
+
+                    >>> app = ServerApp()
+                    >>>
+                    >>> @app.main()
+                    >>> def main(driver: Driver, context: Context) -> None:
+                    >>>    print("ServerApp running")
+                    """,
+                )
+
+            # Register provided function with the ServerApp object
+            self._main = main_fn
+
+            # Return provided function unmodified
+            return main_fn
+
+        return main_decorator
+
+
+class LoadServerAppError(Exception):
+    """Error when trying to load `ServerApp`."""
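A sketch of what a main function registered via `@app.main()` might do with the `Driver` in the new execution mode (illustrative only; `Driver` is imported as in the test file that follows, and the `"train"` message type is an assumption):

    from flwr.common import Context, RecordSet
    from flwr.server import ServerApp
    from flwr.server.driver import Driver

    app = ServerApp()

    @app.main()
    def main(driver: Driver, context: Context) -> None:
        # One explicit round: message every connected node, then wait for replies
        msgs = [
            driver.create_message(RecordSet(), "train", node_id, "1", "")
            for node_id in driver.get_node_ids()
        ]
        for reply in driver.send_and_receive(msgs, timeout=600):
            print("Reply from node", reply.metadata.src_node_id)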
diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py
new file mode 100644
index 000000000000..38c0d6240d90
--- /dev/null
+++ b/src/py/flwr/server/server_app_test.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ServerApp."""
+
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from flwr.common import Context, RecordSet
+from flwr.server import ServerApp, ServerConfig
+from flwr.server.driver import Driver
+
+
+def test_server_app_custom_mode() -> None:
+    """Test the ServerApp custom execution mode."""
+    # Prepare
+    app = ServerApp()
+    driver = MagicMock()
+    context = Context(state=RecordSet())
+
+    called = {"called": False}
+
+    # pylint: disable=unused-argument
+    @app.main()
+    def custom_main(driver: Driver, context: Context) -> None:
+        called["called"] = True
+
+    # pylint: enable=unused-argument
+
+    # Execute
+    app(driver, context)
+
+    # Assert
+    assert called["called"]
+
+
+def test_server_app_exception_when_both_modes() -> None:
+    """Test ServerApp error when both compat mode and custom fns are used."""
+    # Prepare
+    app = ServerApp(config=ServerConfig(num_rounds=3))
+
+    # Execute and assert
+    with pytest.raises(ValueError):
+        # pylint: disable=unused-argument
+        @app.main()
+        def custom_main(driver: Driver, context: Context) -> None:
+            pass
+
+        # pylint: enable=unused-argument
diff --git a/src/py/flwr/server/server_config.py b/src/py/flwr/server/server_config.py
new file mode 100644
index 000000000000..c47367eab4c0
--- /dev/null
+++ b/src/py/flwr/server/server_config.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower ServerConfig."""
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class ServerConfig:
+    """Flower server config.
+
+    All attributes have default values which allow users to configure just the ones
+    they care about.
+ """ + + num_rounds: int = 1 + round_timeout: Optional[float] = None + + def __repr__(self) -> str: + """Return the string representation of the ServerConfig.""" + timeout_string = ( + "no round_timeout" + if self.round_timeout is None + else f"round_timeout={self.round_timeout}s" + ) + return f"num_rounds={self.num_rounds}, {timeout_string}" diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 9b5c03aeeaf9..274e5289fee1 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -45,18 +45,20 @@ class SuccessClient(ClientProxy): """Test class.""" def get_properties( - self, ins: GetPropertiesIns, timeout: Optional[float] + self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int] ) -> GetPropertiesRes: """Raise an error because this method is not expected to be called.""" raise NotImplementedError() def get_parameters( - self, ins: GetParametersIns, timeout: Optional[float] + self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[int] ) -> GetParametersRes: """Raise a error because this method is not expected to be called.""" raise NotImplementedError() - def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: + def fit( + self, ins: FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" arr = np.array([[1, 2], [3, 4], [5, 6]]) arr_serialized = ndarray_to_bytes(arr) @@ -67,7 +69,9 @@ def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: metrics={}, ) - def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: + def evaluate( + self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> EvaluateRes: """Simulate evaluate by returning a success EvaluateRes with loss 1.0.""" return EvaluateRes( status=Status(code=Code.OK, message="Success"), @@ -76,7 +80,9 @@ def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: metrics={}, ) - def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: + def reconnect( + self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int] + ) -> DisconnectRes: """Simulate reconnect by returning a DisconnectRes with UNKNOWN reason.""" return DisconnectRes(reason="UNKNOWN") @@ -85,26 +91,32 @@ class FailingClient(ClientProxy): """Test class.""" def get_properties( - self, ins: GetPropertiesIns, timeout: Optional[float] + self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int] ) -> GetPropertiesRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() def get_parameters( - self, ins: GetParametersIns, timeout: Optional[float] + self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[int] ) -> GetParametersRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() - def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: + def fit( + self, ins: FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> FitRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() - def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: + def evaluate( + self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> EvaluateRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() - def reconnect(self, 
diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py
index 9b5c03aeeaf9..274e5289fee1 100644
--- a/src/py/flwr/server/server_test.py
+++ b/src/py/flwr/server/server_test.py
@@ -45,18 +45,20 @@ class SuccessClient(ClientProxy):
     """Test class."""
 
     def get_properties(
-        self, ins: GetPropertiesIns, timeout: Optional[float]
+        self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetPropertiesRes:
         """Raise an error because this method is not expected to be called."""
         raise NotImplementedError()
 
     def get_parameters(
-        self, ins: GetParametersIns, timeout: Optional[float]
+        self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetParametersRes:
         """Raise an error because this method is not expected to be called."""
         raise NotImplementedError()
 
-    def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
+    def fit(
+        self, ins: FitIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> FitRes:
         """Simulate fit by returning a success FitRes with simple set of weights."""
         arr = np.array([[1, 2], [3, 4], [5, 6]])
         arr_serialized = ndarray_to_bytes(arr)
@@ -67,7 +69,9 @@ def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
             metrics={},
         )
 
-    def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
+    def evaluate(
+        self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> EvaluateRes:
         """Simulate evaluate by returning a success EvaluateRes with loss 1.0."""
         return EvaluateRes(
             status=Status(code=Code.OK, message="Success"),
@@ -76,7 +80,9 @@ def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
             metrics={},
         )
 
-    def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
+    def reconnect(
+        self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> DisconnectRes:
         """Simulate reconnect by returning a DisconnectRes with UNKNOWN reason."""
         return DisconnectRes(reason="UNKNOWN")
 
@@ -85,26 +91,32 @@ class FailingClient(ClientProxy):
     """Test class."""
 
     def get_properties(
-        self, ins: GetPropertiesIns, timeout: Optional[float]
+        self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetPropertiesRes:
         """Raise a NotImplementedError to simulate failure in the client."""
         raise NotImplementedError()
 
     def get_parameters(
-        self, ins: GetParametersIns, timeout: Optional[float]
+        self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetParametersRes:
         """Raise a NotImplementedError to simulate failure in the client."""
         raise NotImplementedError()
 
-    def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
+    def fit(
+        self, ins: FitIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> FitRes:
         """Raise a NotImplementedError to simulate failure in the client."""
         raise NotImplementedError()
 
-    def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
+    def evaluate(
+        self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> EvaluateRes:
         """Raise a NotImplementedError to simulate failure in the client."""
         raise NotImplementedError()
 
-    def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
+    def reconnect(
+        self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> DisconnectRes:
         """Raise a NotImplementedError to simulate failure in the client."""
         raise NotImplementedError()
 
@@ -122,7 +134,7 @@ def test_fit_clients() -> None:
     client_instructions = [(c, ins) for c in clients]
 
     # Execute
-    results, failures = fit_clients(client_instructions, None, None)
+    results, failures = fit_clients(client_instructions, None, None, 0)
 
     # Assert
     assert len(results) == 1
@@ -150,6 +162,7 @@ def test_eval_clients() -> None:
         client_instructions=client_instructions,
         max_workers=None,
         timeout=None,
+        group_id=0,
     )
 
     # Assert
diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py
index 1750a7522379..b7de9a946fff 100644
--- a/src/py/flwr/server/strategy/__init__.py
+++ b/src/py/flwr/server/strategy/__init__.py
@@ -16,6 +16,18 @@
 
 
 from .bulyan import Bulyan as Bulyan
+from .dp_adaptive_clipping import (
+    DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
+)
+from .dp_adaptive_clipping import (
+    DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
+)
+from .dp_fixed_clipping import (
+    DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
+)
+from .dp_fixed_clipping import (
+    DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
+)
 from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
 from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
 from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg
@@ -37,24 +49,28 @@
 from .strategy import Strategy as Strategy
 
 __all__ = [
-    "FaultTolerantFedAvg",
+    "Bulyan",
+    "DPFedAvgAdaptive",
+    "DPFedAvgFixed",
+    "DifferentialPrivacyClientSideAdaptiveClipping",
+    "DifferentialPrivacyServerSideAdaptiveClipping",
+    "DifferentialPrivacyClientSideFixedClipping",
+    "DifferentialPrivacyServerSideFixedClipping",
     "FedAdagrad",
     "FedAdam",
     "FedAvg",
-    "FedXgbNnAvg",
-    "FedXgbBagging",
-    "FedXgbCyclic",
     "FedAvgAndroid",
     "FedAvgM",
+    "FedMedian",
     "FedOpt",
     "FedProx",
-    "FedYogi",
-    "QFedAvg",
-    "FedMedian",
     "FedTrimmedAvg",
+    "FedXgbBagging",
+    "FedXgbCyclic",
+    "FedXgbNnAvg",
+    "FedYogi",
+    "FaultTolerantFedAvg",
     "Krum",
-    "Bulyan",
-    "DPFedAvgAdaptive",
-    "DPFedAvgFixed",
+    "QFedAvg",
     "Strategy",
 ]
diff --git a/src/py/flwr/server/strategy/dp_adaptive_clipping.py b/src/py/flwr/server/strategy/dp_adaptive_clipping.py
new file mode 100644
index 000000000000..d9422c791167
--- /dev/null
+++ b/src/py/flwr/server/strategy/dp_adaptive_clipping.py
@@ -0,0 +1,449 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Central differential privacy with adaptive clipping.
+
+Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
+"""
+
+
+import math
+from logging import WARNING
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from flwr.common import (
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    NDArrays,
+    Parameters,
+    Scalar,
+    ndarrays_to_parameters,
+    parameters_to_ndarrays,
+)
+from flwr.common.differential_privacy import (
+    adaptive_clip_inputs_inplace,
+    add_gaussian_noise_to_params,
+    compute_adaptive_noise_params,
+)
+from flwr.common.differential_privacy_constants import (
+    CLIENTS_DISCREPANCY_WARNING,
+    KEY_CLIPPING_NORM,
+    KEY_NORM_BIT,
+)
+from flwr.common.logger import log
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy.strategy import Strategy
+
+
+class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
+    """Strategy wrapper for central DP with server-side adaptive clipping.
+
+    Parameters
+    ----------
+    strategy : Strategy
+        The strategy to which DP functionalities will be added by this wrapper.
+    noise_multiplier : float
+        The noise multiplier for the Gaussian mechanism for model updates.
+    num_sampled_clients : int
+        The number of clients that are sampled on each round.
+    initial_clipping_norm : float
+        The initial value of the clipping norm. Defaults to 0.1.
+        Andrew et al. recommend setting it to 0.1.
+    target_clipped_quantile : float
+        The desired quantile of updates which should be clipped. Defaults to 0.5.
+    clip_norm_lr : float
+        The learning rate for the clipping norm adaptation. Defaults to 0.2.
+        Andrew et al. recommend setting it to 0.2.
+    clipped_count_stddev : float
+        The standard deviation of the noise added to the count of updates below
+        the estimate. Andrew et al. recommend setting it to
+        `expected_num_records/20`.
+
+    Examples
+    --------
+    Create a strategy:
+
+    >>> strategy = fl.server.strategy.FedAvg( ... )
+
+    Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper:
+
+    >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
+    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
+    >>> )
+    """
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes
+    def __init__(
+        self,
+        strategy: Strategy,
+        noise_multiplier: float,
+        num_sampled_clients: int,
+        initial_clipping_norm: float = 0.1,
+        target_clipped_quantile: float = 0.5,
+        clip_norm_lr: float = 0.2,
+        clipped_count_stddev: Optional[float] = None,
+    ) -> None:
+        super().__init__()
+
+        if strategy is None:
+            raise ValueError("The passed strategy is None.")
+
+        if noise_multiplier < 0:
+            raise ValueError("The noise multiplier should be a non-negative value.")
+
+        if num_sampled_clients <= 0:
+            raise ValueError(
+                "The number of sampled clients should be a positive value."
+            )
+
+        if initial_clipping_norm <= 0:
+            raise ValueError("The initial clipping norm should be a positive value.")
+
+        if not 0 <= target_clipped_quantile <= 1:
+            raise ValueError(
+                "The target clipped quantile must be between 0 and 1 (inclusive)."
+            )
+
+        if clip_norm_lr <= 0:
+            raise ValueError("The learning rate must be positive.")
+
+        if clipped_count_stddev is not None:
+            if clipped_count_stddev < 0:
+                raise ValueError("The `clipped_count_stddev` must be non-negative.")
+
+        self.strategy = strategy
+        self.num_sampled_clients = num_sampled_clients
+        self.clipping_norm = initial_clipping_norm
+        self.target_clipped_quantile = target_clipped_quantile
+        self.clip_norm_lr = clip_norm_lr
+        (
+            self.clipped_count_stddev,
+            self.noise_multiplier,
+        ) = compute_adaptive_noise_params(
+            noise_multiplier,
+            num_sampled_clients,
+            clipped_count_stddev,
+        )
+
+        self.current_round_params: NDArrays = []
+
+    def __repr__(self) -> str:
+        """Compute a string representation of the strategy."""
+        rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
+        return rep
+
+    def initialize_parameters(
+        self, client_manager: ClientManager
+    ) -> Optional[Parameters]:
+        """Initialize global model parameters using given strategy."""
+        return self.strategy.initialize_parameters(client_manager)
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the next round of training."""
+        self.current_round_params = parameters_to_ndarrays(parameters)
+        return self.strategy.configure_fit(server_round, parameters, client_manager)
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """Configure the next round of evaluation."""
+        return self.strategy.configure_evaluate(
+            server_round, parameters, client_manager
+        )
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate training results and update clip norms."""
+        if failures:
+            return None, {}
+
+        if len(results) != self.num_sampled_clients:
+            log(
+                WARNING,
+                CLIENTS_DISCREPANCY_WARNING,
+                len(results),
+                self.num_sampled_clients,
+            )
+
+        norm_bit_set_count = 0
+        for _, res in results:
+            param = parameters_to_ndarrays(res.parameters)
+            # Compute and clip update
+            model_update = [
+                np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
+            ]
+
+            norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
+            norm_bit_set_count += norm_bit
+
+            for i, _ in enumerate(self.current_round_params):
+                param[i] = self.current_round_params[i] + model_update[i]
+            # Convert back to parameters
+            res.parameters = ndarrays_to_parameters(param)
+
+        # Add noise to the count of updates below the norm
+        noised_norm_bit_set_count = float(
+            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
+        )
+        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
+        # Geometric update
+        self.clipping_norm *= math.exp(
+            -self.clip_norm_lr
+            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
+        )
+
+        aggregated_params, metrics = self.strategy.aggregate_fit(
+            server_round, results, failures
+        )
+
+        # Add Gaussian noise to the aggregated parameters
+        if aggregated_params:
+            aggregated_params = add_gaussian_noise_to_params(
+                aggregated_params,
+                self.noise_multiplier,
+                self.clipping_norm,
+                self.num_sampled_clients,
+            )
+
+        return aggregated_params, metrics
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation losses using the given strategy."""
+        return self.strategy.aggregate_evaluate(server_round, results, failures)
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """Evaluate model parameters using an evaluation function from the strategy."""
+        return self.strategy.evaluate(server_round, parameters)
+
+
+class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
+    """Strategy wrapper for central DP with client-side adaptive clipping.
+
+    Use `adaptiveclipping_mod` modifier at the client side.
+
+    In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`,
+    which performs clipping on the server-side, `DifferentialPrivacyClientSideAdaptiveClipping`
+    expects clipping to happen on the client-side, usually by using the built-in
+    `adaptiveclipping_mod`.
+
+    Parameters
+    ----------
+    strategy : Strategy
+        The strategy to which DP functionalities will be added by this wrapper.
+    noise_multiplier : float
+        The noise multiplier for the Gaussian mechanism for model updates.
+    num_sampled_clients : int
+        The number of clients that are sampled on each round.
+    initial_clipping_norm : float
+        The initial value of the clipping norm. Defaults to 0.1.
+        Andrew et al. recommend setting it to 0.1.
+    target_clipped_quantile : float
+        The desired quantile of updates which should be clipped. Defaults to 0.5.
+    clip_norm_lr : float
+        The learning rate for the clipping norm adaptation. Defaults to 0.2.
+        Andrew et al. recommend setting it to 0.2.
+    clipped_count_stddev : float
+        The stddev of the noise added to the count of updates currently below
+        the estimate. Andrew et al. recommend setting it to
+        `expected_num_records/20`.
+
+    Examples
+    --------
+    Create a strategy:
+
+    >>> strategy = fl.server.strategy.FedAvg(...)
+
+    Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper:
+
+    >>> DifferentialPrivacyClientSideAdaptiveClipping(
+    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients
+    >>> )
+
+    On the client, add the `adaptiveclipping_mod` to the client-side mods:
+
+    >>> app = fl.client.ClientApp(
+    >>>     client_fn=client_fn, mods=[adaptiveclipping_mod]
+    >>> )
+    """
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes
+    def __init__(
+        self,
+        strategy: Strategy,
+        noise_multiplier: float,
+        num_sampled_clients: int,
+        initial_clipping_norm: float = 0.1,
+        target_clipped_quantile: float = 0.5,
+        clip_norm_lr: float = 0.2,
+        clipped_count_stddev: Optional[float] = None,
+    ) -> None:
+        super().__init__()
+
+        if strategy is None:
+            raise ValueError("The passed strategy is None.")
+
+        if noise_multiplier < 0:
+            raise ValueError("The noise multiplier should be a non-negative value.")
+
+        if num_sampled_clients <= 0:
+            raise ValueError(
+                "The number of sampled clients should be a positive value."
+            )
+
+        if initial_clipping_norm <= 0:
+            raise ValueError("The initial clipping norm should be a positive value.")
+
+        if not 0 <= target_clipped_quantile <= 1:
+            raise ValueError(
+                "The target clipped quantile must be between 0 and 1 (inclusive)."
+            )
+
+        if clip_norm_lr <= 0:
+            raise ValueError("The learning rate must be positive.")
+
+        if clipped_count_stddev is not None and clipped_count_stddev < 0:
+            raise ValueError("The `clipped_count_stddev` must be non-negative.")
+
+        self.strategy = strategy
+        self.num_sampled_clients = num_sampled_clients
+        self.clipping_norm = initial_clipping_norm
+        self.target_clipped_quantile = target_clipped_quantile
+        self.clip_norm_lr = clip_norm_lr
+        (
+            self.clipped_count_stddev,
+            self.noise_multiplier,
+        ) = compute_adaptive_noise_params(
+            noise_multiplier,
+            num_sampled_clients,
+            clipped_count_stddev,
+        )
+
+    def __repr__(self) -> str:
+        """Compute a string representation of the strategy."""
+        rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
+        return rep
+
+    def initialize_parameters(
+        self, client_manager: ClientManager
+    ) -> Optional[Parameters]:
+        """Initialize global model parameters using given strategy."""
+        return self.strategy.initialize_parameters(client_manager)
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the next round of training."""
+        additional_config = {KEY_CLIPPING_NORM: self.clipping_norm}
+        inner_strategy_config_result = self.strategy.configure_fit(
+            server_round, parameters, client_manager
+        )
+        for _, fit_ins in inner_strategy_config_result:
+            fit_ins.config.update(additional_config)
+
+        return inner_strategy_config_result
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """Configure the next round of evaluation."""
+        return self.strategy.configure_evaluate(
+            server_round, parameters, client_manager
+        )
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate training results and update clip norms."""
+        if failures:
+            return None, {}
+
+        if len(results) != self.num_sampled_clients:
+            log(
+                WARNING,
+                CLIENTS_DISCREPANCY_WARNING,
+                len(results),
+                self.num_sampled_clients,
+            )
+
+        aggregated_params, metrics = self.strategy.aggregate_fit(
+            server_round, results, failures
+        )
+        self._update_clip_norm(results)
+
+        # Add Gaussian noise to the aggregated parameters
+        if aggregated_params:
+            aggregated_params = add_gaussian_noise_to_params(
+                aggregated_params,
+                self.noise_multiplier,
+                self.clipping_norm,
+                self.num_sampled_clients,
+            )
+
+        return aggregated_params, metrics
+
+    def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
+        # Calculate the number of clients which set the norm indicator bit
+        norm_bit_set_count = 0
+        for client_proxy, fit_res in results:
+            if KEY_NORM_BIT not in fit_res.metrics:
+                raise KeyError(
+                    f"{KEY_NORM_BIT} not returned by client with id {client_proxy.cid}."
+                )
+            if fit_res.metrics[KEY_NORM_BIT]:
+                norm_bit_set_count += 1
+        # Add noise to the count
+        noised_norm_bit_set_count = float(
+            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
+        )
+
+        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
+        # Geometric update
+        self.clipping_norm *= math.exp(
+            -self.clip_norm_lr
+            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
+        )
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation losses using the given strategy."""
+        return self.strategy.aggregate_evaluate(server_round, results, failures)
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """Evaluate model parameters using an evaluation function from the strategy."""
+        return self.strategy.evaluate(server_round, parameters)
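Both adaptive-clipping wrappers share the geometric norm update from Andrew et al.: each round, the clipping norm is multiplied by exp(-clip_norm_lr * (fraction_below - target_quantile)), so the norm shrinks when more updates than targeted already fit under it, and grows otherwise. A self-contained numeric sketch of that rule (illustrative values; the count noising is omitted for readability):

    import math

    clipping_norm = 0.1        # initial clipping norm
    clip_norm_lr = 0.2         # adaptation learning rate
    target_quantile = 0.5      # desired fraction of updates under the norm

    # Round 1: 8 of 10 updates fell below the norm -> norm is too large, shrink it
    fraction_below = 8 / 10
    clipping_norm *= math.exp(-clip_norm_lr * (fraction_below - target_quantile))
    print(f"{clipping_norm:.4f}")  # 0.0942, factor e^(-0.2*0.3) ~= 0.9418

    # Round 2: only 2 of 10 updates fell below -> norm is too small, grow it back
    fraction_below = 2 / 10
    clipping_norm *= math.exp(-clip_norm_lr * (fraction_below - target_quantile))
    print(f"{clipping_norm:.4f}")  # 0.1000, factor e^(+0.2*0.3) ~= 1.0618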
) + + Wrap the strategy with the DifferentialPrivacyServerSideFixedClipping wrapper + + >>> dp_strategy = DifferentialPrivacyServerSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping norm should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + self.current_round_params: NDArrays = [] + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + self.current_round_params = parameters_to_ndarrays(parameters) + return self.strategy.configure_fit(server_round, parameters, client_manager) + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Compute the updates, clip, and pass them for aggregation. + + Afterward, add noise to the aggregated parameters. 
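+ + In sketch form (the names below are illustrative, not part of the API), each client update is rescaled so that its L2 norm is at most the clipping norm, and the aggregate then receives Gaussian noise whose standard deviation follows the usual DP-FedAvg scaling: + + >>> clip_factor = min(1.0, clipping_norm / update_norm) + >>> clipped_update = [layer * clip_factor for layer in update] + >>> noise_std = noise_multiplier * clipping_norm / num_sampled_clients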
+ """ + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + for _, res in results: + param = parameters_to_ndarrays(res.parameters) + # Compute and clip update + compute_clip_model_update( + param, self.current_round_params, self.clipping_norm + ) + # Convert back to parameters + res.parameters = ndarrays_to_parameters(param) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) + + +class DifferentialPrivacyClientSideFixedClipping(Strategy): + """Strategy wrapper for central DP with client-side fixed clipping. + + Use `fixedclipping_mod` modifier at the client side. + + In comparison to `DifferentialPrivacyServerSideFixedClipping`, + which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping` + expects clipping to happen on the client-side, usually by using the built-in + `fixedclipping_mod`. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg(...) + + Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper: + + >>> DifferentialPrivacyClientSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + + On the client, add the `fixedclipping_mod` to the client-side mods: + + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[fixedclipping_mod] + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping threshold should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." 
+ ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Add noise to the aggregated parameters.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 8b3278cc9ba0..a908679ed668 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -24,6 +24,7 @@ import numpy as np from flwr.common import FitIns, FitRes, Parameters, Scalar +from flwr.common.logger import warn_deprecated_feature from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.strategy.dpfedavg_fixed import DPFedAvgFixed @@ -31,7 +32,12 @@ class DPFedAvgAdaptive(DPFedAvgFixed): - """Wrapper for configuring a Strategy for DP with Adaptive Clipping.""" + """Wrapper for configuring a 
Strategy for DP with Adaptive Clipping. + + Warning + ------- + This class is deprecated and will be removed in a future release. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( @@ -45,6 +51,7 @@ def __init__( clip_norm_target_quantile: float = 0.5, clip_count_stddev: Optional[float] = None, ) -> None: + warn_deprecated_feature("`DPFedAvgAdaptive` wrapper") super().__init__( strategy=strategy, num_sampled_clients=num_sampled_clients, diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index f2f1c206f3de..c54379fc7087 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -17,11 +17,11 @@ Paper: arxiv.org/pdf/1710.06963.pdf """ - from typing import Dict, List, Optional, Tuple, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.dp import add_gaussian_noise +from flwr.common.logger import warn_deprecated_feature from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy @@ -29,7 +29,12 @@ class DPFedAvgFixed(Strategy): - """Wrapper for configuring a Strategy for DP with Fixed Clipping.""" + """Wrapper for configuring a Strategy for DP with Fixed Clipping. + + Warning + ------- + This class is deprecated and will be removed in a future release. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( @@ -40,6 +45,7 @@ def __init__( noise_multiplier: float = 1, server_side_noising: bool = True, ) -> None: + warn_deprecated_feature("`DPFedAvgFixed` wrapper") super().__init__() self.strategy = strategy # Doing fixed-size subsampling as in https://arxiv.org/abs/1905.03871. @@ -98,9 +104,9 @@ def configure_fit( """ additional_config = {"dpfedavg_clip_norm": self.clip_norm} if not self.server_side_noising: - additional_config[ - "dpfedavg_noise_stddev" - ] = self._calc_client_noise_stddev() + additional_config["dpfedavg_noise_stddev"] = ( + self._calc_client_noise_stddev() + ) client_instructions = self.strategy.configure_fit( server_round, parameters, client_manager diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index b3380a5be2f9..0c966442ecaf 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .fedadagrad import FedAdagrad diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index e4b126823fb6..3b9b2640c2b5 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -84,6 +84,8 @@ class FedAvg(Strategy): Metrics aggregation function, optional. evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] Metrics aggregation function, optional. + inplace : bool (default: True) + Enable (True) or disable (False) in-place aggregation of model updates. 
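+ + In sketch form (illustrative only, not the exact implementation), both settings compute the same examples-weighted average of client updates; `inplace=True` merely accumulates the sum into a single buffer to reduce peak memory usage: + + >>> # agg = sum_i (num_examples_i / total_examples) * weights_i + >>> for weights, num_examples in client_results: + >>> factor = num_examples / total_examples + >>> for buf, layer in zip(agg, weights): + >>> buf += factor * layer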
""" # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long diff --git a/src/py/flwr/server/strategy/fedmedian_test.py b/src/py/flwr/server/strategy/fedmedian_test.py index 180503df6c80..57cf08d8c01d 100644 --- a/src/py/flwr/server/strategy/fedmedian_test.py +++ b/src/py/flwr/server/strategy/fedmedian_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .fedmedian import FedMedian diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index cafb466c2e8b..a8e8adddafbb 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -44,6 +44,11 @@ def __init__( self.global_model: Optional[bytes] = None super().__init__(**kwargs) + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedXgbBagging(accept_failures={self.accept_failures})" + return rep + def aggregate_fit( self, server_round: int, diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py index e2707b02d19d..2605daab29f4 100644 --- a/src/py/flwr/server/strategy/fedxgb_cyclic.py +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -37,6 +37,11 @@ def __init__( self.global_model: Optional[bytes] = None super().__init__(**kwargs) + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedXgbCyclic(accept_failures={self.accept_failures})" + return rep + def aggregate_fit( self, server_round: int, diff --git a/src/py/flwr/server/strategy/krum_test.py b/src/py/flwr/server/strategy/krum_test.py index 81e59230739a..653dc9a8475d 100644 --- a/src/py/flwr/server/strategy/krum_test.py +++ b/src/py/flwr/server/strategy/krum_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .krum import Krum diff --git a/src/py/flwr/server/strategy/multikrum_test.py b/src/py/flwr/server/strategy/multikrum_test.py index 1469db104252..f874dc2f9800 100644 --- a/src/py/flwr/server/strategy/multikrum_test.py +++ b/src/py/flwr/server/strategy/multikrum_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .krum import Krum diff --git a/src/py/flwr/server/superlink/__init__.py b/src/py/flwr/server/superlink/__init__.py new file mode 100644 index 000000000000..94102100de26 --- /dev/null +++ b/src/py/flwr/server/superlink/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower SuperLink.""" diff --git a/src/py/flwr/driver/__init__.py b/src/py/flwr/server/superlink/driver/__init__.py similarity index 78% rename from src/py/flwr/driver/__init__.py rename to src/py/flwr/server/superlink/driver/__init__.py index 1c3b09cc334b..2bfe63e6065f 100644 --- a/src/py/flwr/driver/__init__.py +++ b/src/py/flwr/server/superlink/driver/__init__.py @@ -12,15 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower driver SDK.""" - - -from .app import start_driver -from .driver import Driver -from .grpc_driver import GrpcDriver - -__all__ = [ - "Driver", - "GrpcDriver", - "start_driver", -] +"""Flower driver service.""" diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py new file mode 100644 index 000000000000..f74000bc59c4 --- /dev/null +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -0,0 +1,54 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Driver gRPC API.""" + +from logging import INFO +from typing import Optional, Tuple + +import grpc + +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 + add_DriverServicer_to_server, +) +from flwr.server.superlink.state import StateFactory + +from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server +from .driver_servicer import DriverServicer + + +def run_driver_api_grpc( + address: str, + state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], +) -> grpc.Server: + """Run Driver API (gRPC, request-response).""" + # Create Driver API gRPC server + driver_servicer: grpc.Server = DriverServicer( + state_factory=state_factory, + ) + driver_add_servicer_to_server_fn = add_DriverServicer_to_server + driver_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + certificates=certificates, + ) + + log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address) + driver_grpc_server.start() + + return driver_grpc_server diff --git a/src/py/flwr/server/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py similarity index 92% rename from src/py/flwr/server/driver/driver_servicer.py rename to src/py/flwr/server/superlink/driver/driver_servicer.py index 275cc8ac6a03..59e51ef52d8e 100644 --- a/src/py/flwr/server/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -15,7 +15,7 @@ """Driver API servicer.""" -from logging import INFO +from logging import DEBUG, INFO from typing import List, Optional, Set from uuid import UUID @@ -35,7 +35,7 @@ ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 -from flwr.server.state import State, StateFactory +from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -49,7 +49,7 @@ def GetNodes( self, request: GetNodesRequest, context: grpc.ServicerContext ) -> GetNodesResponse: """Get available nodes.""" - log(INFO, "DriverServicer.GetNodes") + log(DEBUG, "DriverServicer.GetNodes") state: State = self.state_factory.state() all_ids: Set[int] = state.get_nodes(request.run_id) nodes: List[Node] = [ @@ -70,7 +70,7 @@ def PushTaskIns( self, request: PushTaskInsRequest, context: grpc.ServicerContext ) -> PushTaskInsResponse: """Push a set of TaskIns.""" - log(INFO, "DriverServicer.PushTaskIns") + log(DEBUG, "DriverServicer.PushTaskIns") # Validate request _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty") @@ -95,7 +95,7 @@ def PullTaskRes( self, request: PullTaskResRequest, context: grpc.ServicerContext ) -> PullTaskResResponse: """Pull a set of TaskRes.""" - log(INFO, "DriverServicer.PullTaskRes") + log(DEBUG, "DriverServicer.PullTaskRes") # Convert each task_id str to UUID task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids} @@ -105,7 +105,7 @@ def PullTaskRes( # Register callback def on_rpc_done() -> None: - log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes") + log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes") if context.is_active(): return diff --git a/src/py/flwr/server/driver/driver_servicer_test.py 
b/src/py/flwr/server/superlink/driver/driver_servicer_test.py similarity index 95% rename from src/py/flwr/server/driver/driver_servicer_test.py rename to src/py/flwr/server/superlink/driver/driver_servicer_test.py index c432c026a632..99f7cc007a89 100644 --- a/src/py/flwr/server/driver/driver_servicer_test.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer_test.py @@ -15,7 +15,7 @@ """DriverServicer tests.""" -from flwr.server.driver.driver_servicer import _raise_if +from flwr.server.superlink.driver.driver_servicer import _raise_if # pylint: disable=broad-except diff --git a/src/py/flwr/server/fleet/__init__.py b/src/py/flwr/server/superlink/fleet/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/__init__.py rename to src/py/flwr/server/superlink/fleet/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_bidi/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/grpc_bidi/__init__.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py similarity index 91% rename from src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py index 6eccb056390a..6f94ea844e38 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py @@ -18,7 +18,7 @@ - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md """ - +import uuid from typing import Callable, Iterator import grpc @@ -30,8 +30,12 @@ ServerMessage, ) from flwr.server.client_manager import ClientManager -from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( + GrpcBridge, + InsWrapper, + ResWrapper, +) +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def default_bridge_factory() -> GrpcBridge: @@ -91,9 +95,12 @@ def Join( # pylint: disable=invalid-name wrapping the actual message - The `Join` method is (pretty much) unaware of the protocol """ - peer: str = context.peer() + # When running Flower behind a proxy, the peer can be the same for + # different clients, so instead of `cid: str = context.peer()` we + # use a `UUID4` that is unique. 
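+ # Illustrative example (hypothetical value): `uuid.uuid4().hex` is a + # 32-character hex string, e.g. "9f1c2e4d0a8b4c6e9d2f1a3b5c7d9e0f", + # so two clients behind the same proxy still receive distinct cids.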
+ cid: str = uuid.uuid4().hex bridge = self.grpc_bridge_factory() - client_proxy = self.client_proxy_factory(peer, bridge) + client_proxy = self.client_proxy_factory(cid, bridge) is_success = register_client_proxy(self.client_manager, client_proxy, context) if is_success: diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py similarity index 91% rename from src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py index b5c3f504af03..bd93554a6a32 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py @@ -16,21 +16,23 @@ import unittest +import uuid from unittest.mock import MagicMock, call from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, ServerMessage, ) -from flwr.server.fleet.grpc_bidi.flower_service_servicer import ( +from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import ( FlowerServiceServicer, register_client_proxy, ) -from flwr.server.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() -CLIENT_CID = "some_client_cid" + +CID: str = uuid.uuid4().hex class FlowerServiceServicerTestCase(unittest.TestCase): @@ -42,7 +44,6 @@ def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() - self.context_mock.peer.return_value = CLIENT_CID # Define client_messages to be processed by FlowerServiceServicer instance self.client_messages = [CLIENT_MESSAGE for _ in range(5)] @@ -70,7 +71,7 @@ def setUp(self) -> None: # Create a GrpcClientProxy mock which we will use to test if correct # methods where called and client_messages are getting passed to it self.grpc_client_proxy_mock = MagicMock() - self.grpc_client_proxy_mock.cid = CLIENT_CID + self.grpc_client_proxy_mock.cid = CID self.client_proxy_factory_mock = MagicMock() self.client_proxy_factory_mock.return_value = self.grpc_client_proxy_mock @@ -127,11 +128,7 @@ def test_join(self) -> None: num_server_messages += 1 assert len(self.client_messages) == num_server_messages - assert self.grpc_client_proxy_mock.cid == CLIENT_CID - - self.client_proxy_factory_mock.assert_called_once_with( - CLIENT_CID, self.grpc_bridge_mock - ) + assert self.grpc_client_proxy_mock.cid == CID # Check if the client was registered with the client_manager self.client_manager_mock.register.assert_called_once_with( diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py similarity index 100% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py similarity index 98% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py index 6527c45d7d6c..f7c236acd7a1 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py @@ -23,7 +23,7 @@ ClientMessage, ServerMessage, ) -from 
flwr.server.fleet.grpc_bidi.grpc_bridge import ( +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( GrpcBridge, GrpcBridgeClosed, InsWrapper, diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py similarity index 94% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py index 46185896561e..ac62ad014950 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py @@ -24,7 +24,11 @@ ServerMessage, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( + GrpcBridge, + InsWrapper, + ResWrapper, +) class GrpcClientProxy(ClientProxy): @@ -42,6 +46,7 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Request client's set of internal properties.""" get_properties_msg = serde.get_properties_ins_to_proto(ins) @@ -61,6 +66,7 @@ def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" get_parameters_msg = serde.get_parameters_ins_to_proto(ins) @@ -80,6 +86,7 @@ def fit( self, ins: common.FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.FitRes: """Refine the provided parameters using the locally held dataset.""" fit_ins_msg = serde.fit_ins_to_proto(ins) @@ -98,6 +105,7 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" evaluate_msg = serde.evaluate_ins_to_proto(ins) @@ -115,6 +123,7 @@ def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" reconnect_ins_msg = serde.reconnect_ins_to_proto(ins) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py similarity index 92% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py index 1a417ae433d5..e7077dfd39ae 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py @@ -27,8 +27,8 @@ Parameters, Scalar, ) -from flwr.server.fleet.grpc_bidi.grpc_bridge import ResWrapper -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage( @@ -71,7 +71,7 @@ def test_get_parameters(self) -> None: # Execute value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None + ins=get_parameters_ins, timeout=None, group_id=0 ) # Assert @@ -88,7 +88,7 @@ def test_fit(self) -> None: ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) # Execute - fit_res = client.fit(ins=ins, 
timeout=None) + fit_res = client.fit(ins=ins, timeout=None, group_id=0) # Assert assert fit_res.parameters.tensor_type == "np" @@ -106,7 +106,7 @@ def test_evaluate(self) -> None: evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) # Assert assert (0, 0.0) == ( @@ -127,7 +127,9 @@ def test_get_properties(self) -> None: ) # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) # Assert assert value.properties["tensor_type"] == "numpy.ndarray" diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py similarity index 97% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_server.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index e05df88dcd12..82f049844bd6 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -28,9 +28,11 @@ add_FlowerServiceServicer_to_server, ) from flwr.server.client_manager import ClientManager -from flwr.server.driver.driver_servicer import DriverServicer -from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer -from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer +from flwr.server.superlink.driver.driver_servicer import DriverServicer +from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import ( + FlowerServiceServicer, +) +from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer INVALID_CERTIFICATES_ERR_MSG = """ When setting any of root_certificate, certificate, or private_key, diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py similarity index 96% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py index 4cd093d6ab0f..8afa37515950 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py @@ -23,12 +23,12 @@ from typing import Tuple, cast from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_server import ( +from flwr.server.superlink.fleet.grpc_bidi.grpc_server import ( start_grpc_server, valid_certificates, ) -root_dir = dirname(abspath(join(__file__, "../../../../../.."))) +root_dir = dirname(abspath(join(__file__, "../../../../../../.."))) def load_certificates() -> Tuple[str, str, str]: diff --git a/src/py/flwr/server/fleet/grpc_rere/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/grpc_rere/__init__.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py similarity index 84% rename from src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index b12f365e898c..278474477379 100644 --- a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -31,15 +31,15 @@ 
PushTaskResRequest, PushTaskResResponse, ) -from flwr.server.fleet.message_handler import message_handler -from flwr.server.state import State +from flwr.server.superlink.fleet.message_handler import message_handler +from flwr.server.superlink.state import StateFactory class FleetServicer(fleet_pb2_grpc.FleetServicer): """Fleet API servicer.""" - def __init__(self, state: State) -> None: - self.state = state + def __init__(self, state_factory: StateFactory) -> None: + self.state_factory = state_factory def CreateNode( self, request: CreateNodeRequest, context: grpc.ServicerContext @@ -48,7 +48,7 @@ def CreateNode( log(INFO, "FleetServicer.CreateNode") return message_handler.create_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def DeleteNode( @@ -58,7 +58,7 @@ def DeleteNode( log(INFO, "FleetServicer.DeleteNode") return message_handler.delete_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def PullTaskIns( @@ -68,7 +68,7 @@ def PullTaskIns( log(INFO, "FleetServicer.PullTaskIns") return message_handler.pull_task_ins( request=request, - state=self.state, + state=self.state_factory.state(), ) def PushTaskRes( @@ -78,5 +78,5 @@ def PushTaskRes( log(INFO, "FleetServicer.PushTaskRes") return message_handler.push_task_res( request=request, - state=self.state, + state=self.state_factory.state(), ) diff --git a/src/py/flwr/server/fleet/message_handler/__init__.py b/src/py/flwr/server/superlink/fleet/message_handler/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/message_handler/__init__.py rename to src/py/flwr/server/superlink/fleet/message_handler/__init__.py diff --git a/src/py/flwr/server/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py similarity index 98% rename from src/py/flwr/server/fleet/message_handler/message_handler.py rename to src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 8d451c896ed9..5fe815180823 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -31,7 +31,7 @@ ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.state import State +from flwr.server.superlink.state import State def create_node( diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py similarity index 100% rename from src/py/flwr/server/fleet/message_handler/message_handler_test.py rename to src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py diff --git a/src/py/flwr/server/fleet/rest_rere/__init__.py b/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/rest_rere/__init__.py rename to src/py/flwr/server/superlink/fleet/rest_rere/__init__.py diff --git a/src/py/flwr/server/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py similarity index 98% rename from src/py/flwr/server/fleet/rest_rere/rest_api.py rename to src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index b815558cb099..b022b34c68c8 100644 --- a/src/py/flwr/server/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -24,8 +24,8 @@ PullTaskInsRequest, PushTaskResRequest, ) -from flwr.server.fleet.message_handler 
import message_handler -from flwr.server.state import State +from flwr.server.superlink.fleet.message_handler import message_handler +from flwr.server.superlink.state import State try: from starlette.applications import Starlette diff --git a/src/py/flwr/server/superlink/fleet/vce/__init__.py b/src/py/flwr/server/superlink/fleet/vce/__init__.py new file mode 100644 index 000000000000..57d39688b527 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fleet Simulation Engine side.""" + +from .vce_api import start_vce + +__all__ = [ + "start_vce", +] diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py new file mode 100644 index 000000000000..d751cf4bcae1 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simulation Engine Backends.""" + +import importlib +from typing import Dict, Type + +from .backend import Backend, BackendConfig + +is_ray_installed = importlib.util.find_spec("ray") is not None + +# Mapping of supported backends +supported_backends: Dict[str, Type[Backend]] = {} + +# To log backend-specific error message when chosen backend isn't available +error_messages_backends: Dict[str, str] = {} + +if is_ray_installed: + from .raybackend import RayBackend + + supported_backends["ray"] = RayBackend +else: + error_messages_backends[ + "ray" + ] = """Unable to import module `ray`. + + To install the necessary dependencies, install `flwr` with the `simulation` extra: + + pip install -U flwr["simulation"] + """ + + +__all__ = [ + "Backend", + "BackendConfig", +] diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py new file mode 100644 index 000000000000..1d5e3a6a51ad --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -0,0 +1,67 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generic Backend class for Fleet API using the Simulation Engine.""" + + +from abc import ABC, abstractmethod +from typing import Callable, Dict, Tuple + +from flwr.client.client_app import ClientApp +from flwr.common.context import Context +from flwr.common.message import Message +from flwr.common.typing import ConfigsRecordValues + +BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]] + + +class Backend(ABC): + """Abstract base class for a Simulation Engine Backend.""" + + def __init__(self, backend_config: BackendConfig, work_dir: str) -> None: + """Construct a backend.""" + + @abstractmethod + async def build(self) -> None: + """Build backend asynchronously. + + Different components need to be in place before workers in a backend are ready + to accept jobs. When this method finishes executing, the backend should be fully + ready to run jobs. + """ + + @property + def num_workers(self) -> int: + """Return number of workers in the backend. + + This is the number of TaskIns that can be processed concurrently. + """ + return 0 + + @abstractmethod + def is_worker_idle(self) -> bool: + """Report whether a backend worker is idle and can therefore run a ClientApp.""" + + @abstractmethod + async def terminate(self) -> None: + """Terminate backend.""" + + @abstractmethod + async def process_message( + self, + app: Callable[[], ClientApp], + message: Message, + context: Context, + ) -> Tuple[Message, Context]: + """Submit a job to the backend.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py new file mode 100644 index 000000000000..8ef0d54622ae --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -0,0 +1,175 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Ray backend for the Fleet API using the Simulation Engine.""" + +import pathlib +from logging import ERROR, INFO +from typing import Callable, Dict, List, Tuple, Union + +import ray + +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common.context import Context +from flwr.common.logger import log +from flwr.common.message import Message +from flwr.simulation.ray_transport.ray_actor import ( + BasicActorPool, + ClientAppActor, + init_ray, +) +from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth + +from .backend import Backend, BackendConfig + +ClientResourcesDict = Dict[str, Union[int, float]] + + +class RayBackend(Backend): + """A backend that submits jobs to a `BasicActorPool`.""" + + def __init__( + self, + backend_config: BackendConfig, + work_dir: str, + ) -> None: + """Prepare RayBackend by initialising Ray and creating the ActorPool.""" + log(INFO, "Initialising: %s", self.__class__.__name__) + log(INFO, "Backend config: %s", backend_config) + + if not pathlib.Path(work_dir).exists(): + raise ValueError(f"Specified work_dir {work_dir} does not exist.") + + # Init ray and append working dir if needed + runtime_env = ( + self._configure_runtime_env(work_dir=work_dir) if work_dir else None + ) + init_ray(runtime_env=runtime_env) + + # Validate client resources + self.client_resources_key = "client_resources" + + # Create actor pool + use_tf = backend_config.get("tensorflow", False) + actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {} + + client_resources = self._validate_client_resources(config=backend_config) + self.pool = BasicActorPool( + actor_type=ClientAppActor, + client_resources=client_resources, + actor_kwargs=actor_kwargs, + ) + + def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]: + """Return a Ray runtime env for work_dir, excluding non-Python files. + + Without these excludes, Ray would push everything in work_dir to the Ray Cluster. + """ + runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir} + + excludes = [] + path = pathlib.Path(work_dir) + for p in path.rglob("*"): + # Excluded files need to be relative to the working_dir + if p.is_file() and not str(p).endswith(".py"): + excludes.append(str(p.relative_to(path))) + runtime_env["excludes"] = excludes + + return runtime_env + + def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict: + client_resources_config = config.get(self.client_resources_key) + client_resources: ClientResourcesDict = {} + valid_types = (int, float) + if client_resources_config: + for k, v in client_resources_config.items(): + if not isinstance(k, str): + raise ValueError( + f"client resources keys are expected to be `str` but you used " + f"{type(k)} for `{k}`" + ) + if not isinstance(v, valid_types): + raise ValueError( + f"client resources are expected to be of type {valid_types} " + f"but found `{type(v)}` for key `{k}`", + ) + client_resources[k] = v + + else: + client_resources = {"num_cpus": 2, "num_gpus": 0.0} + log( + INFO, + "`%s` not specified in backend config.
Applying default setting: %s", + self.client_resources_key, + client_resources, + ) + + return client_resources + + @property + def num_workers(self) -> int: + """Return number of actors in pool.""" + return self.pool.num_actors + + def is_worker_idle(self) -> bool: + """Report whether the pool has idle actors.""" + return self.pool.is_actor_available() + + async def build(self) -> None: + """Build pool of Ray actors that this backend will submit jobs to.""" + await self.pool.add_actors_to_pool(self.pool.actors_capacity) + log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors) + + async def process_message( + self, + app: Callable[[], ClientApp], + message: Message, + context: Context, + ) -> Tuple[Message, Context]: + """Run the ClientApp that processes a given message. + + Return output message and updated context. + """ + partition_id = message.metadata.partition_id + + try: + # Submit a task to the pool + future = await self.pool.submit( + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (app, message, str(partition_id), context), + ) + + await future + + # Fetch result + ( + out_mssg, + updated_context, + ) = await self.pool.fetch_result_and_return_actor_to_pool(future) + + return out_mssg, updated_context + + except LoadClientAppError as load_ex: + log( + ERROR, + "An exception was raised when processing a message by %s", + self.__class__.__name__, + ) + raise load_ex + + async def terminate(self) -> None: + """Terminate all actors in actor pool.""" + await self.pool.terminate_all_actors() + ray.shutdown() + log(INFO, "Terminated %s", self.__class__.__name__) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py new file mode 100644 index 000000000000..2610307bb749 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -0,0 +1,216 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Test for Ray backend for the Fleet API using the Simulation Engine.""" + +import asyncio +from math import pi +from pathlib import Path +from typing import Callable, Dict, Optional, Tuple, Union +from unittest import IsolatedAsyncioTestCase + +import ray + +from flwr.client import Client, NumPyClient +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common import ( + Config, + ConfigsRecord, + Context, + GetPropertiesIns, + Message, + MessageTypeLegacy, + Metadata, + RecordSet, + Scalar, +) +from flwr.common.object_ref import load_app +from flwr.common.recordset_compat import getpropertiesins_to_recordset +from flwr.server.superlink.fleet.vce.backend.raybackend import RayBackend + + +class DummyClient(NumPyClient): + """A dummy NumPyClient for tests.""" + + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = float(config["factor"]) * pi + + # Store something in the context + self.context.state.configs_records["result"] = ConfigsRecord({"result": result}) + return {"result": result} + + +def get_dummy_client(cid: str) -> Client: # pylint: disable=unused-argument + """Return a DummyClient converted to Client type.""" + return DummyClient().to_client() + + +def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) + + +client_app = ClientApp( + client_fn=get_dummy_client, +) + + +def _load_from_module(client_app_module_name: str) -> Callable[[], ClientApp]: + def _load_app() -> ClientApp: + app = load_app(client_app_module_name, LoadClientAppError) + + if not isinstance(app, ClientApp): + raise LoadClientAppError( + f"Attribute {client_app_module_name} is not of type {ClientApp}", + ) from None + + return app + + return _load_app + + +async def backend_build_process_and_termination( + backend: RayBackend, + process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, +) -> Union[Tuple[Message, Context], None]: + """Build, process a job, and terminate RayBackend.""" + await backend.build() + to_return = None + + if process_args: + to_return = await backend.process_message(*process_args) + + await backend.terminate() + + return to_return + + +def _create_message_and_context() -> Tuple[Message, Context, float]: + + # Construct a Message + mult_factor = 2024 + getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) + recordset = getpropertiesins_to_recordset(getproperties_ins) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + ), + ) + + # Construct empty Context + context = Context(state=RecordSet()) + + # Expected output + expected_output = pi * mult_factor + + return message, context, expected_output + + +class AsyncTestRayBackend(IsolatedAsyncioTestCase): + """A basic class that allows running multiple asyncio tests.""" + + async def on_cleanup(self) -> None: + """Ensure Ray has shut down.""" + if ray.is_initialized(): + ray.shutdown() + + def test_backend_creation_and_termination(self) -> None: + """Test creation of RayBackend and its termination.""" + backend = RayBackend(backend_config={}, work_dir="") + asyncio.run( + backend_build_process_and_termination(backend=backend, process_args=None) + ) + + def test_backend_creation_submit_and_termination( + self, + client_app_loader:
Callable[[], ClientApp] = _load_app, + workdir: str = "", + ) -> None: + """Test submitting a message to a given ClientApp.""" + backend = RayBackend(backend_config={}, work_dir=workdir) + + # Define ClientApp + client_app_callable = client_app_loader + + message, context, expected_output = _create_message_and_context() + + res = asyncio.run( + backend_build_process_and_termination( + backend=backend, process_args=(client_app_callable, message, context) + ) + ) + + if res is None: + raise AssertionError("This shouldn't happen") + + out_mssg, updated_context = res + + # Verify message content is as expected + content = out_mssg.content + assert ( + content.configs_records["getpropertiesres.properties"]["result"] + == expected_output + ) + + # Verify context is correct + obtained_result_in_context = updated_context.state.configs_records["result"][ + "result" + ] + assert obtained_result_in_context == expected_output + + def test_backend_creation_submit_and_termination_non_existing_client_app( + self, + ) -> None: + """Test with a ClientApp module that does not exist.""" + with self.assertRaises(LoadClientAppError): + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("a_non_existing_module:app") + ) + self.addAsyncCleanup(self.on_cleanup) + + def test_backend_creation_submit_and_termination_existing_client_app( + self, + ) -> None: + """Test with a ClientApp module that exists.""" + # Resolve the workdir to pass upon Backend initialisation + file_path = Path(__file__) + working_dir = Path.cwd() + rel_workdir = file_path.relative_to(working_dir) + + # Drop the last element (the file name) + rel_workdir_str = str(rel_workdir.parent) + + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("raybackend_test:client_app"), + workdir=rel_workdir_str, + ) + + def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdir( + self, + ) -> None: + """Test with a ClientApp module that exists but a workdir that does not.""" + with self.assertRaises(ValueError): + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("raybackend_test:client_app"), + workdir="/?&%$^#%@$!", + ) + self.addAsyncCleanup(self.on_cleanup) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py new file mode 100644 index 000000000000..a693c968d0e8 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -0,0 +1,331 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ==============================================================================
+"""Fleet Simulation Engine API."""
+
+
+import asyncio
+import json
+import traceback
+from logging import DEBUG, ERROR, INFO, WARN
+from typing import Callable, Dict, List, Optional
+
+from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.client.node_state import NodeState
+from flwr.common.logger import log
+from flwr.common.object_ref import load_app
+from flwr.common.serde import message_from_taskins, message_to_taskres
+from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
+from flwr.server.superlink.state import StateFactory
+
+from .backend import Backend, error_messages_backends, supported_backends
+
+NodeToPartitionMapping = Dict[int, int]
+
+
+def _register_nodes(
+    num_nodes: int, state_factory: StateFactory
+) -> NodeToPartitionMapping:
+    """Register nodes with the StateFactory and create node-id:partition-id mapping."""
+    nodes_mapping: NodeToPartitionMapping = {}
+    state = state_factory.state()
+    for i in range(num_nodes):
+        node_id = state.create_node()
+        nodes_mapping[node_id] = i
+    log(INFO, "Registered %i nodes", len(nodes_mapping))
+    return nodes_mapping
+
+
+# pylint: disable=too-many-arguments,too-many-locals
+async def worker(
+    app_fn: Callable[[], ClientApp],
+    queue: "asyncio.Queue[TaskIns]",
+    node_states: Dict[int, NodeState],
+    state_factory: StateFactory,
+    nodes_mapping: NodeToPartitionMapping,
+    backend: Backend,
+) -> None:
+    """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
+    state = state_factory.state()
+    while True:
+        try:
+            task_ins: TaskIns = await queue.get()
+            node_id = task_ins.task.consumer.node_id
+
+            # Register and retrieve run state
+            node_states[node_id].register_context(run_id=task_ins.run_id)
+            context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
+
+            # Convert TaskIns to Message
+            message = message_from_taskins(task_ins)
+            # Set partition_id
+            message.metadata.partition_id = nodes_mapping[node_id]
+
+            # Let backend process message
+            out_mssg, updated_context = await backend.process_message(
+                app_fn, message, context
+            )
+
+            # Update Context
+            node_states[node_id].update_context(
+                task_ins.run_id, context=updated_context
+            )
+
+            # Convert to TaskRes
+            task_res = message_to_taskres(out_mssg)
+            # Store TaskRes in state
+            state.store_task_res(task_res)
+
+        except asyncio.CancelledError as e:
+            log(DEBUG, "Async worker: %s", e)
+            break
+
+        except LoadClientAppError as app_ex:
+            log(ERROR, "Async worker: %s", app_ex)
+            log(ERROR, traceback.format_exc())
+            raise
+
+        except Exception as ex:  # pylint: disable=broad-exception-caught
+            log(ERROR, ex)
+            log(ERROR, traceback.format_exc())
+            break
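+
+
+# `worker` (above) and `add_taskins_to_queue` (below) form a producer/consumer
+# pair: the producer polls the state for TaskIns and fills the queue, while one
+# `worker` task per Backend worker drains it and hands Messages to the Backend.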
+async def add_taskins_to_queue(
+    queue: "asyncio.Queue[TaskIns]",
+    state_factory: StateFactory,
+    nodes_mapping: NodeToPartitionMapping,
+    backend: Backend,
+    consumers: List["asyncio.Task[None]"],
+    f_stop: asyncio.Event,
+) -> None:
+    """Retrieve TaskIns and add it to the queue."""
+    state = state_factory.state()
+    num_initial_consumers = len(consumers)
+    while not f_stop.is_set():
+        for node_id in nodes_mapping.keys():
+            task_ins = state.get_task_ins(node_id=node_id, limit=1)
+            if task_ins:
+                await queue.put(task_ins[0])
+
+        # Count consumers that are running
+        num_active = sum(not (cc.done()) for cc in consumers)
+
+        # Alert if the number of active consumers has more than halved
+        if num_active < num_initial_consumers // 2:
+            log(
+                WARN,
+                "Number of active workers has more than halved: (%i/%i active)",
+                num_active,
+                num_initial_consumers,
+            )
+
+        # Break if consumers died
+        if num_active == 0:
+            raise RuntimeError("All workers have died. Ending Simulation.")
+
+        # Log some stats
+        log(
+            DEBUG,
+            "Simulation Engine stats: "
+            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
+            num_active,
+            num_initial_consumers,
+            backend.__class__.__name__,
+            backend.num_workers,
+            queue.qsize(),
+        )
+        await asyncio.sleep(1.0)
+    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
+
+
+async def run(
+    app_fn: Callable[[], ClientApp],
+    backend_fn: Callable[[], Backend],
+    nodes_mapping: NodeToPartitionMapping,
+    state_factory: StateFactory,
+    node_states: Dict[int, NodeState],
+    f_stop: asyncio.Event,
+) -> None:
+    """Run the VCE async."""
+    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
+
+    try:
+
+        # Instantiate backend
+        backend = backend_fn()
+
+        # Build backend
+        await backend.build()
+
+        # Add workers (they submit Messages to Backend)
+        worker_tasks = [
+            asyncio.create_task(
+                worker(
+                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
+                )
+            )
+            for _ in range(backend.num_workers)
+        ]
+        # Create producer (adds TaskIns into Queue)
+        producer = asyncio.create_task(
+            add_taskins_to_queue(
+                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
+            )
+        )
+
+        # Wait for producer to finish
+        # The producer runs until f_stop is set or until
+        # all worker (consumer) coroutines are completed. Workers
+        # also run forever and only end if an exception is raised.
+        await asyncio.gather(producer)
+
+    except Exception as ex:
+
+        log(ERROR, "An exception occurred!! %s", ex)
+        log(ERROR, traceback.format_exc())
+        log(WARN, "Stopping Simulation Engine.")
+
+        # Manually trigger stopping event
+        f_stop.set()
+
+        # Raise exception
+        raise RuntimeError("Simulation Engine crashed.") from ex
+
+    finally:
+        # Producer task terminated, now cancel worker tasks
+        for w_t in worker_tasks:
+            _ = w_t.cancel()
+
+        while not all(w_t.done() for w_t in worker_tasks):
+            log(DEBUG, "Terminating async workers...")
+            await asyncio.sleep(0.5)
+
+        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
+
+        # Terminate backend
+        await backend.terminate()
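+
+
+# Usage sketch for `start_vce` (defined below), mirroring how the tests in
+# vce_api_test.py invoke the engine; the module:attribute path
+# "myproject.client:app" is a hypothetical example:
+#
+#   f_stop = asyncio.Event()
+#   start_vce(
+#       backend_name="ray",
+#       backend_config_json_stream="{}",
+#       app_dir="",
+#       f_stop=f_stop,
+#       client_app_attr="myproject.client:app",
+#       num_supernodes=50,
+#   )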
+# pylint: disable=too-many-arguments,unused-argument,too-many-locals
+def start_vce(
+    backend_name: str,
+    backend_config_json_stream: str,
+    app_dir: str,
+    f_stop: asyncio.Event,
+    client_app: Optional[ClientApp] = None,
+    client_app_attr: Optional[str] = None,
+    num_supernodes: Optional[int] = None,
+    state_factory: Optional[StateFactory] = None,
+    existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
+) -> None:
+    """Start Fleet API with the Simulation Engine."""
+    if client_app_attr is not None and client_app is not None:
+        raise ValueError(
+            "Both `client_app_attr` and `client_app` are provided, "
+            "but only one is allowed."
+        )
+
+    if num_supernodes is not None and existing_nodes_mapping is not None:
+        raise ValueError(
+            "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
+            "but only one is allowed."
+        )
+    if num_supernodes is None:
+        if state_factory is None or existing_nodes_mapping is None:
+            raise ValueError(
+                "If not passing an existing `state_factory` and associated "
+                "`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
+                "how many nodes to insert into a new StateFactory that will be created."
+            )
+    if existing_nodes_mapping:
+        if state_factory is None:
+            raise ValueError(
+                "`existing_nodes_mapping` was passed, but no `state_factory` was "
+                "passed."
+            )
+        log(INFO, "Using existing NodeToPartitionMapping and StateFactory.")
+        # Use mapping constructed externally. This also means nodes
+        # have previously been registered.
+        nodes_mapping = existing_nodes_mapping
+
+    if not state_factory:
+        log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
+        # Create an empty in-memory state factory
+        state_factory = StateFactory(":flwr-in-memory-state:")
+        log(INFO, "Created new %s.", state_factory.__class__.__name__)
+
+    if num_supernodes:
+        # Register SuperNodes
+        nodes_mapping = _register_nodes(
+            num_nodes=num_supernodes, state_factory=state_factory
+        )
+
+    # Construct mapping of NodeStates
+    node_states: Dict[int, NodeState] = {}
+    for node_id in nodes_mapping:
+        node_states[node_id] = NodeState()
+
+    # Load backend config
+    log(INFO, "Supported backends: %s", list(supported_backends.keys()))
+    backend_config = json.loads(backend_config_json_stream)
+
+    try:
+        backend_type = supported_backends[backend_name]
+    except KeyError as ex:
+        log(
+            ERROR,
+            "Backend `%s` is not supported. Use any of %s or add support "
+            "for a new backend.",
+            backend_name,
+            list(supported_backends.keys()),
+        )
+        if backend_name in error_messages_backends:
+            log(ERROR, error_messages_backends[backend_name])
+
+        raise ex
+
+    def backend_fn() -> Backend:
+        """Instantiate a Backend."""
+        return backend_type(backend_config, work_dir=app_dir)
+
+    log(INFO, "client_app_attr = %s", client_app_attr)
+
+    # Load ClientApp if needed
+    def _load() -> ClientApp:
+
+        if client_app_attr:
+            app: ClientApp = load_app(client_app_attr, LoadClientAppError)
+
+            if not isinstance(app, ClientApp):
+                raise LoadClientAppError(
+                    f"Attribute {client_app_attr} is not of type {ClientApp}",
+                ) from None
+
+        if client_app:
+            app = client_app
+        return app
+
+    app_fn = _load
+
+    asyncio.run(
+        run(
+            app_fn,
+            backend_fn,
+            nodes_mapping,
+            state_factory,
+            node_states,
+            f_stop,
+        )
+    )
diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
new file mode 100644
index 000000000000..8c37399ae295
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
@@ -0,0 +1,291 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test Fleet Simulation Engine API."""
+
+
+import asyncio
+import threading
+from itertools import cycle
+from json import JSONDecodeError
+from math import pi
+from pathlib import Path
+from time import sleep
+from typing import Dict, Optional, Set, Tuple
+from unittest import IsolatedAsyncioTestCase
+from uuid import UUID
+
+from flwr.common import GetPropertiesIns, Message, MessageTypeLegacy, Metadata
+from flwr.common.recordset_compat import getpropertiesins_to_recordset
+from flwr.common.serde import message_from_taskres, message_to_taskins
+from flwr.server.superlink.fleet.vce.vce_api import (
+    NodeToPartitionMapping,
+    _register_nodes,
+    start_vce,
+)
+from flwr.server.superlink.state import InMemoryState, StateFactory
+
+
+def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None:
+    """Set event to terminate Simulation Engine after `sleep_duration` seconds."""
+    sleep(sleep_duration)
+    f_stop.set()
+
+
+def init_state_factory_nodes_mapping(
+    num_nodes: int,
+    num_messages: int,
+    erroneous_message: Optional[bool] = False,
+) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]:
+    """Instantiate StateFactory, register nodes and pre-insert messages in the state."""
+    # Register a state and a run_id in it
+    run_id = 1234
+    state_factory = StateFactory(":flwr-in-memory-state:")
+
+    # Register a few nodes
+    nodes_mapping = _register_nodes(num_nodes=num_nodes, state_factory=state_factory)
+
+    expected_results = register_messages_into_state(
+        state_factory=state_factory,
+        nodes_mapping=nodes_mapping,
+        run_id=run_id,
+        num_messages=num_messages,
+        erroneous_message=erroneous_message,
+    )
+    return state_factory, nodes_mapping, expected_results
+
+
+# pylint: disable=too-many-locals
+def register_messages_into_state(
+    state_factory: StateFactory,
+    nodes_mapping: NodeToPartitionMapping,
+    run_id: int,
+    num_messages: int,
+    erroneous_message: Optional[bool] = False,
+) -> Dict[UUID, float]:
+    """Register `num_messages` into the state factory."""
+    state: InMemoryState = state_factory.state()  # type: ignore
+    state.run_ids.add(run_id)
+    # Artificially add TaskIns to state so they can be processed
+    # by the Simulation Engine logic
+    nodes_cycle = cycle(nodes_mapping.keys())  # we have more messages than supernodes
+    task_ids: Set[UUID] = set()  # so we can retrieve them later
+    expected_results = {}
+    for i in range(num_messages):
+        dst_node_id = next(nodes_cycle)
+        # Construct a Message
+        mult_factor = 2024 + i
+        getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
+        recordset = getpropertiesins_to_recordset(getproperties_ins)
+        message = Message(
+            content=recordset,
+            metadata=Metadata(
+                run_id=run_id,
+                message_id="",
+                group_id="",
+                src_node_id=0,
+                dst_node_id=dst_node_id,  # indicate destination node
+                reply_to_message="",
+                ttl="",
+                message_type=(
+                    "a bad message"
+                    if erroneous_message
+                    else MessageTypeLegacy.GET_PROPERTIES
+                ),
+            ),
+        )
+        # Convert Message to TaskIns
+        taskins = message_to_taskins(message)
+        # Insert into state
+        task_id = state.store_task_ins(taskins)
+        if task_id:
+            # Add to UUID set
+            task_ids.add(task_id)
+            # Store expected output for check later on
+            expected_results[task_id] = mult_factor * pi
+
+    return expected_results
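+
+
+# Note: the expected result registered above (`mult_factor * pi`) mirrors the
+# computation in `DummyClient.get_properties` (see raybackend_test.py), which
+# multiplies the received "factor" by pi.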
+def _autoresolve_app_dir(rel_client_app_dir: str = "backend") -> str:
+    """Correctly resolve working directory for the app."""
+    file_path = Path(__file__)
+    app_dir = Path.cwd()
+    rel_app_dir = file_path.relative_to(app_dir)
+
+    # Subtract the last element and append the directory where the client module is
+    return str(rel_app_dir.parent / rel_client_app_dir)
+
+
+# pylint: disable=too-many-arguments
+def start_and_shutdown(
+    backend: str = "ray",
+    client_app_attr: str = "raybackend_test:client_app",
+    app_dir: str = "",
+    num_supernodes: Optional[int] = None,
+    state_factory: Optional[StateFactory] = None,
+    nodes_mapping: Optional[NodeToPartitionMapping] = None,
+    duration: int = 0,
+    backend_config: str = "{}",
+) -> None:
+    """Start Simulation Engine and terminate after specified number of seconds.
+
+    Some tests need to be terminated by externally triggering an asyncio.Event. This
+    is enabled when passing `duration`>0.
+    """
+    f_stop = asyncio.Event()
+
+    if duration:
+
+        # Set up a thread that will set the f_stop event, triggering the termination
+        # of all asyncio logic in the Simulation Engine. It will also terminate the
+        # Backend.
+        termination_th = threading.Thread(
+            target=terminate_simulation, args=(f_stop, duration)
+        )
+        termination_th.start()
+
+    # Resolve working directory if not passed
+    if not app_dir:
+        app_dir = _autoresolve_app_dir()
+
+    start_vce(
+        num_supernodes=num_supernodes,
+        client_app_attr=client_app_attr,
+        backend_name=backend,
+        backend_config_json_stream=backend_config,
+        state_factory=state_factory,
+        app_dir=app_dir,
+        f_stop=f_stop,
+        existing_nodes_mapping=nodes_mapping,
+    )
+
+    if duration:
+        termination_th.join()
+
+
+class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase):
+    """A basic class that enables testing asyncio functionalities."""
+
+    def test_erroneous_no_supernodes_client_mapping(self) -> None:
+        """Test with unset arguments."""
+        with self.assertRaises(ValueError):
+            start_and_shutdown(duration=2)
+
+    def test_erroneous_client_app_attr(self) -> None:
+        """Test an attempt to load a ClientApp that can't be found."""
+        num_messages = 7
+        num_nodes = 59
+
+        state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping(
+            num_nodes=num_nodes, num_messages=num_messages
+        )
+        with self.assertRaises(RuntimeError):
+            start_and_shutdown(
+                client_app_attr="totally_fictitious_app:client",
+                state_factory=state_factory,
+                nodes_mapping=nodes_mapping,
+            )
+
+    def test_erroneous_messages(self) -> None:
+        """Test handling of an error in the async worker (consumer).
+
+        We register messages that trigger an error when being handled.
+        """
+        num_messages = 100
+        num_nodes = 59
+
+        state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping(
+            num_nodes=num_nodes, num_messages=num_messages, erroneous_message=True
+        )
+
+        with self.assertRaises(RuntimeError):
+            start_and_shutdown(
+                state_factory=state_factory,
+                nodes_mapping=nodes_mapping,
+            )
+
+    def test_erroneous_backend_config(self) -> None:
+        """Backend Config should be a JSON stream."""
+        with self.assertRaises(JSONDecodeError):
+            start_and_shutdown(num_supernodes=50, backend_config="not a proper config")
+
+    def test_with_nonexistent_backend(self) -> None:
+        """Test specifying a backend that does not exist."""
+        with self.assertRaises(KeyError):
+            start_and_shutdown(num_supernodes=50, backend="this-backend-does-not-exist")
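+
+    # The two tests below exercise the argument contract of `start_vce`: a node
+    # mapping implies that nodes were registered externally (and hence that a
+    # state factory exists), so redundant or incomplete combinations must fail.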
+    def test_erroneous_arguments_num_supernodes_and_existing_mapping(self) -> None:
+        """Test ValueError if a node mapping is passed but also num_supernodes.
+
+        Passing `num_supernodes` does nothing since we assume that if a node mapping
+        is supplied, nodes have been registered externally already. Therefore passing
+        `num_supernodes` might give the impression that that many nodes will be
+        registered. We don't do that since a mapping already exists.
+        """
+        with self.assertRaises(ValueError):
+            start_and_shutdown(num_supernodes=50, nodes_mapping={0: 1})
+
+    def test_erroneous_arguments_existing_mapping_but_no_state_factory(self) -> None:
+        """Test ValueError if a node mapping is passed but no state.
+
+        Passing a node mapping indicates that (externally) nodes have registered with a
+        state factory. Therefore, that state factory should be passed too.
+        """
+        with self.assertRaises(ValueError):
+            start_and_shutdown(nodes_mapping={0: 1})
+
+    def test_start_and_shutdown(self) -> None:
+        """Start Simulation Engine Fleet and terminate it."""
+        start_and_shutdown(num_supernodes=50, duration=10)
+
+    # pylint: disable=too-many-locals
+    def test_start_and_shutdown_with_tasks_in_state(self) -> None:
+        """Run Simulation Engine with some TaskIns in State.
+
+        This test creates a few nodes and submits a few messages that need to be
+        executed by the Backend. In order for that to happen the asyncio
+        producer/consumer logic must function. This also serves to evaluate a valid
+        ClientApp.
+        """
+        num_messages = 229
+        num_nodes = 59
+
+        state_factory, nodes_mapping, expected_results = (
+            init_state_factory_nodes_mapping(
+                num_nodes=num_nodes, num_messages=num_messages
+            )
+        )
+
+        # Run
+        start_and_shutdown(
+            state_factory=state_factory, nodes_mapping=nodes_mapping, duration=10
+        )
+
+        # Get all TaskRes
+        state = state_factory.state()
+        task_ids = set(expected_results.keys())
+        task_res_list = state.get_task_res(task_ids=task_ids, limit=len(task_ids))
+
+        # Check results by first converting to Message
+        for task_res in task_res_list:
+
+            message = message_from_taskres(task_res)
+
+            # Verify message content is as expected
+            content = message.content
+            assert (
+                content.configs_records["getpropertiesres.properties"]["result"]
+                == expected_results[UUID(task_res.task.ancestry[0])]
+            )
diff --git a/src/py/flwr/server/state/__init__.py b/src/py/flwr/server/superlink/state/__init__.py
similarity index 100%
rename from src/py/flwr/server/state/__init__.py
rename to src/py/flwr/server/superlink/state/__init__.py
diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py
similarity index 69%
rename from src/py/flwr/server/state/in_memory_state.py
rename to src/py/flwr/server/superlink/state/in_memory_state.py
index f21845fcb909..ac1ab158e254 100644
--- a/src/py/flwr/server/state/in_memory_state.py
+++ b/src/py/flwr/server/superlink/state/in_memory_state.py
@@ -16,6 +16,7 @@
 import os
+import threading
 from datetime import datetime, timedelta
 from logging import ERROR
 from typing import Dict, List, Optional, Set
@@ -23,7 +24,7 @@
 from flwr.common import log, now
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.server.state.state import State
+from flwr.server.superlink.state.state import State
 from flwr.server.utils import validate_task_ins_or_res
@@ -35,6 +36,7 @@ def __init__(self) -> None:
         self.run_ids: Set[int] = set()
         self.task_ins_store: Dict[UUID, TaskIns] = {}
         self.task_res_store: Dict[UUID, TaskRes] = {}
+        self.lock = threading.Lock()
 
     def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
         """Store one TaskIns."""
@@ -57,7 +59,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
         task_ins.task_id = str(task_id)
         task_ins.task.created_at = created_at.isoformat()
         task_ins.task.ttl = ttl.isoformat()
-        self.task_ins_store[task_id] = task_ins
+        with self.lock:
+            self.task_ins_store[task_id] = task_ins
 
         # Return the new task_id
         return task_id
@@ -71,22 +74,23 @@ def get_task_ins(
 
         # Find TaskIns for node_id that were not delivered yet
         task_ins_list: List[TaskIns] = []
-        for _, task_ins in self.task_ins_store.items():
-            # pylint: disable=too-many-boolean-expressions
-            if (
-                node_id is not None  # Not anonymous
-                and task_ins.task.consumer.anonymous is False
-                and task_ins.task.consumer.node_id == node_id
-                and task_ins.task.delivered_at == ""
-            ) or (
-                node_id is None  # Anonymous
-                and task_ins.task.consumer.anonymous is True
-                and task_ins.task.consumer.node_id == 0
-                and task_ins.task.delivered_at == ""
-            ):
-                task_ins_list.append(task_ins)
-            if limit and len(task_ins_list) == limit:
-                break
+        with self.lock:
+            for _, task_ins in self.task_ins_store.items():
+                # pylint: disable=too-many-boolean-expressions
+                if (
+                    node_id is not None  # Not anonymous
+                    and task_ins.task.consumer.anonymous is False
+                    and task_ins.task.consumer.node_id == node_id
+                    and task_ins.task.delivered_at == ""
+                ) or (
+                    node_id is None  # Anonymous
+                    and task_ins.task.consumer.anonymous is True
+                    and task_ins.task.consumer.node_id == 0
+                    and task_ins.task.delivered_at == ""
+                ):
+                    task_ins_list.append(task_ins)
+                if limit and len(task_ins_list) == limit:
+                    break
 
         # Mark all of them as delivered
         delivered_at = now().isoformat()
@@ -118,7 +122,8 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
         task_res.task_id = str(task_id)
         task_res.task.created_at = created_at.isoformat()
         task_res.task.ttl = ttl.isoformat()
-        self.task_res_store[task_id] = task_res
+        with self.lock:
+            self.task_res_store[task_id] = task_res
 
         # Return the new task_id
         return task_id
@@ -128,45 +133,47 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
         if limit is not None and limit < 1:
             raise AssertionError("`limit` must be >= 1")
 
-        # Find TaskRes that were not delivered yet
-        task_res_list: List[TaskRes] = []
-        for _, task_res in self.task_res_store.items():
-            if (
-                UUID(task_res.task.ancestry[0]) in task_ids
-                and task_res.task.delivered_at == ""
-            ):
-                task_res_list.append(task_res)
-            if limit and len(task_res_list) == limit:
-                break
-
-        # Mark all of them as delivered
-        delivered_at = now().isoformat()
-        for task_res in task_res_list:
-            task_res.task.delivered_at = delivered_at
-
-        # Return TaskRes
-        return task_res_list
+        with self.lock:
+            # Find TaskRes that were not delivered yet
+            task_res_list: List[TaskRes] = []
+            for _, task_res in self.task_res_store.items():
+                if (
+                    UUID(task_res.task.ancestry[0]) in task_ids
+                    and task_res.task.delivered_at == ""
+                ):
+                    task_res_list.append(task_res)
+                if limit and len(task_res_list) == limit:
+                    break
+
+            # Mark all of them as delivered
+            delivered_at = now().isoformat()
+            for task_res in task_res_list:
+                task_res.task.delivered_at = delivered_at
+
+            # Return TaskRes
+            return task_res_list
 
     def delete_tasks(self, task_ids: Set[UUID]) -> None:
         """Delete all delivered TaskIns/TaskRes pairs."""
         task_ins_to_be_deleted: Set[UUID] = set()
         task_res_to_be_deleted: Set[UUID] = set()
-        for task_ins_id in task_ids:
-            # Find the task_id of the matching task_res
-            for task_res_id, task_res in self.task_res_store.items():
-                if UUID(task_res.task.ancestry[0]) != task_ins_id:
-                    continue
-                if task_res.task.delivered_at == "":
-                    continue
-
-                task_ins_to_be_deleted.add(task_ins_id)
-                task_res_to_be_deleted.add(task_res_id)
-
-        for task_id in task_ins_to_be_deleted:
-            del self.task_ins_store[task_id]
-        for task_id in task_res_to_be_deleted:
-            del self.task_res_store[task_id]
+        with self.lock:
+            for task_ins_id in task_ids:
+                # Find the task_id of the matching task_res
+                for task_res_id, task_res in self.task_res_store.items():
+                    if UUID(task_res.task.ancestry[0]) != task_ins_id:
+                        continue
+                    if task_res.task.delivered_at == "":
+                        continue
+
+                    task_ins_to_be_deleted.add(task_ins_id)
+                    task_res_to_be_deleted.add(task_res_id)
+
+            for task_id in task_ins_to_be_deleted:
+                del self.task_ins_store[task_id]
+            for task_id in task_res_to_be_deleted:
+                del self.task_res_store[task_id]
 
     def num_task_ins(self) -> int:
         """Calculate the number of task_ins in store.
diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py
similarity index 95%
rename from src/py/flwr/server/state/sqlite_state.py
rename to src/py/flwr/server/superlink/state/sqlite_state.py
index 538ecb84491f..224c16cdf013 100644
--- a/src/py/flwr/server/state/sqlite_state.py
+++ b/src/py/flwr/server/superlink/state/sqlite_state.py
@@ -25,11 +25,8 @@
 from flwr.common import log, now
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
-    ClientMessage,
-    ServerMessage,
-)
 from flwr.server.utils.validator import validate_task_ins_or_res
 
 from .state import State
@@ -50,7 +47,7 @@
 CREATE TABLE IF NOT EXISTS task_ins(
     task_id TEXT UNIQUE,
     group_id TEXT,
-    run_id INTEGER,
+    run_id INTEGER,
     producer_anonymous BOOLEAN,
     producer_node_id INTEGER,
    consumer_anonymous BOOLEAN,
@@ -59,8 +56,8 @@
     delivered_at TEXT,
     ttl TEXT,
     ancestry TEXT,
-    legacy_server_message BLOB,
-    legacy_client_message BLOB,
+    task_type TEXT,
+    recordset BLOB,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """
@@ -70,7 +67,7 @@
 CREATE TABLE IF NOT EXISTS task_res(
     task_id TEXT UNIQUE,
     group_id TEXT,
-    run_id INTEGER,
+    run_id INTEGER,
     producer_anonymous BOOLEAN,
     producer_node_id INTEGER,
     consumer_anonymous BOOLEAN,
@@ -79,8 +76,8 @@
     delivered_at TEXT,
     ttl TEXT,
     ancestry TEXT,
-    legacy_server_message BLOB,
-    legacy_client_message BLOB,
+    task_type TEXT,
+    recordset BLOB,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """
@@ -549,10 +546,8 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
         "delivered_at": task_msg.task.delivered_at,
         "ttl": task_msg.task.ttl,
         "ancestry": ",".join(task_msg.task.ancestry),
-        "legacy_server_message": (
-            task_msg.task.legacy_server_message.SerializeToString()
-        ),
-        "legacy_client_message": None,
+        "task_type": task_msg.task.task_type,
+        "recordset": task_msg.task.recordset.SerializeToString(),
     }
     return result
@@ -571,18 +566,16 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
         "delivered_at": task_msg.task.delivered_at,
         "ttl": task_msg.task.ttl,
         "ancestry": ",".join(task_msg.task.ancestry),
-        "legacy_server_message": None,
-        "legacy_client_message": (
-            task_msg.task.legacy_client_message.SerializeToString()
-        ),
+        "task_type": task_msg.task.task_type,
+        "recordset": task_msg.task.recordset.SerializeToString(),
     }
     return result
 
 
 def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
     """Turn task_dict into protobuf message."""
-    server_message = ServerMessage()
-    server_message.ParseFromString(task_dict["legacy_server_message"])
+    recordset = RecordSet()
+    recordset.ParseFromString(task_dict["recordset"])
 
     result = TaskIns(
         task_id=task_dict["task_id"],
@@ -601,7 +594,8 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
             delivered_at=task_dict["delivered_at"],
             ttl=task_dict["ttl"],
             ancestry=task_dict["ancestry"].split(","),
-            legacy_server_message=server_message,
+            task_type=task_dict["task_type"],
+            recordset=recordset,
         ),
     )
     return result
@@ -609,8 +603,8 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
 
 def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
     """Turn task_dict into protobuf message."""
-    client_message = ClientMessage()
-    client_message.ParseFromString(task_dict["legacy_client_message"])
+    recordset = RecordSet()
+    recordset.ParseFromString(task_dict["recordset"])
 
     result = TaskRes(
         task_id=task_dict["task_id"],
@@ -629,7 +623,8 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
             delivered_at=task_dict["delivered_at"],
             ttl=task_dict["ttl"],
             ancestry=task_dict["ancestry"].split(","),
-            legacy_client_message=client_message,
+            task_type=task_dict["task_type"],
+            recordset=recordset,
         ),
     )
     return result
diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/superlink/state/sqlite_state_test.py
similarity index 89%
rename from src/py/flwr/server/state/sqlite_state_test.py
rename to src/py/flwr/server/superlink/state/sqlite_state_test.py
index a3f899386011..9eef71e396e3 100644
--- a/src/py/flwr/server/state/sqlite_state_test.py
+++ b/src/py/flwr/server/superlink/state/sqlite_state_test.py
@@ -17,8 +17,8 @@
 
 import unittest
 
-from flwr.server.state.sqlite_state import task_ins_to_dict
-from flwr.server.state.state_test import create_task_ins
+from flwr.server.superlink.state.sqlite_state import task_ins_to_dict
+from flwr.server.superlink.state.state_test import create_task_ins
 
 
 class SqliteStateTest(unittest.TestCase):
@@ -40,8 +40,8 @@ def test_ins_res_to_dict(self) -> None:
             "delivered_at",
             "ttl",
             "ancestry",
-            "legacy_server_message",
-            "legacy_client_message",
+            "task_type",
+            "recordset",
         ]
 
         # Execute
diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/superlink/state/state.py
similarity index 100%
rename from src/py/flwr/server/state/state.py
rename to src/py/flwr/server/superlink/state/state.py
diff --git a/src/py/flwr/server/state/state_factory.py b/src/py/flwr/server/superlink/state/state_factory.py
similarity index 100%
rename from src/py/flwr/server/state/state_factory.py
rename to src/py/flwr/server/superlink/state/state_factory.py
diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py
similarity index 97%
rename from src/py/flwr/server/state/state_test.py
rename to src/py/flwr/server/superlink/state/state_test.py
index 7f9094625765..d0470a7ce7f7 100644
--- a/src/py/flwr/server/state/state_test.py
+++ b/src/py/flwr/server/superlink/state/state_test.py
@@ -23,12 +23,9 @@
 from uuid import uuid4
 
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
-    ClientMessage,
-    ServerMessage,
-)
-from flwr.server.state import InMemoryState, SqliteState, State
+from flwr.server.superlink.state import InMemoryState, SqliteState, State
 
 
 class StateTest(unittest.TestCase):
@@ -421,9 +418,8 @@ def create_task_ins(
             delivered_at=delivered_at,
             producer=Node(node_id=0, anonymous=True),
             consumer=consumer,
-            legacy_server_message=ServerMessage(
-                reconnect_ins=ServerMessage.ReconnectIns()
-            ),
+            task_type="mock",
+            recordset=RecordSet(parameters={}, metrics={}, configs={}),
         ),
     )
     return task
@@ -444,9 +440,8 @@ def create_task_res(
             producer=Node(node_id=producer_node_id, anonymous=anonymous),
             consumer=Node(node_id=0, anonymous=True),
             ancestry=ancestry,
-            legacy_client_message=ClientMessage(
-                disconnect_res=ClientMessage.DisconnectRes()
-            ),
+            task_type="mock",
+            recordset=RecordSet(parameters={}, metrics={}, configs={}),
         ),
     )
     return task_res
diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py
new file mode 100644
index 000000000000..01143af74392
--- /dev/null
+++ b/src/py/flwr/server/typing.py
@@ -0,0 +1,25 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Custom types for Flower servers."""
+
+
+from typing import Callable
+
+from flwr.common import Context
+
+from .driver import Driver
+
+ServerAppCallable = Callable[[Driver, Context], None]
+Workflow = Callable[[Driver, Context], None]
diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py
index 01dbcf982cce..f9b271beafdc 100644
--- a/src/py/flwr/server/utils/validator.py
+++ b/src/py/flwr/server/utils/validator.py
@@ -64,21 +64,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
             validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
 
         # Content check
-        has_fields = {
-            "sa": tasks_ins_res.task.HasField("sa"),
-            "legacy_server_message": tasks_ins_res.task.HasField(
-                "legacy_server_message"
-            ),
-        }
-        if not (has_fields["sa"] or has_fields["legacy_server_message"]):
-            err_msg = ", ".join([f"`{field}`" for field in has_fields])
-            validation_errors.append(
-                f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}"
-            )
-        if has_fields[
-            "legacy_server_message"
-        ] and not tasks_ins_res.task.legacy_server_message.HasField("msg"):
-            validation_errors.append("`legacy_server_message` does not set field `msg`")
+        if tasks_ins_res.task.task_type == "":
+            validation_errors.append("`task_type` MUST be set")
+        if not tasks_ins_res.task.HasField("recordset"):
+            validation_errors.append("`recordset` MUST be set")
 
         # Ancestors
         if len(tasks_ins_res.task.ancestry) != 0:
@@ -115,21 +104,10 @@
             validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
 
         # Content check
-        has_fields = {
-            "sa": tasks_ins_res.task.HasField("sa"),
-            "legacy_client_message": tasks_ins_res.task.HasField(
-                "legacy_client_message"
-            ),
-        }
-        if not (has_fields["sa"] or has_fields["legacy_client_message"]):
-            err_msg = ", ".join([f"`{field}`" for field in has_fields])
-            validation_errors.append(
-                f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}"
-            )
-        if has_fields[
-            "legacy_client_message"
-        ] and not tasks_ins_res.task.legacy_client_message.HasField("msg"):
-            validation_errors.append("`legacy_client_message` does not set field `msg`")
+        if tasks_ins_res.task.task_type == "":
+            validation_errors.append("`task_type` MUST be set")
+        if not tasks_ins_res.task.HasField("recordset"):
+            validation_errors.append("`recordset` MUST be set")
 
         # Ancestors
         if len(tasks_ins_res.task.ancestry) == 0:
diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py
index a93e4fb4d457..8e0849508020 100644
--- a/src/py/flwr/server/utils/validator_test.py
+++ b/src/py/flwr/server/utils/validator_test.py
@@ -19,16 +19,8 @@
 from typing import List, Tuple
 
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
-from flwr.proto.task_pb2 import (  # pylint: disable=E0611
-    SecureAggregation,
-    Task,
-    TaskIns,
-    TaskRes,
-)
-from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
-    ClientMessage,
-    ServerMessage,
-)
+from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 
 from .validator import validate_task_ins_or_res
 
@@ -45,16 +37,12 @@ def test_task_ins(self) -> None:
 
         # Execute & Assert
         for consumer_node_id, anonymous in valid_ins:
-            msg = create_task_ins(
-                consumer_node_id, anonymous, has_legacy_server_message=True
-            )
+            msg = create_task_ins(consumer_node_id, anonymous)
             val_errors = validate_task_ins_or_res(msg)
             self.assertFalse(val_errors)
 
         for consumer_node_id, anonymous in invalid_ins:
-            msg = create_task_ins(
-                consumer_node_id, anonymous, has_legacy_server_message=True
-            )
+            msg = create_task_ins(consumer_node_id, anonymous)
             val_errors = validate_task_ins_or_res(msg)
             self.assertTrue(val_errors)
 
@@ -78,61 +66,19 @@ def test_is_valid_task_res(self) -> None:
 
         # Execute & Assert
         for producer_node_id, anonymous, ancestry in valid_res:
-            msg = create_task_res(
-                producer_node_id, anonymous, ancestry, has_legacy_client_message=True
-            )
+            msg = create_task_res(producer_node_id, anonymous, ancestry)
             val_errors = validate_task_ins_or_res(msg)
             self.assertFalse(val_errors)
 
         for producer_node_id, anonymous, ancestry in invalid_res:
-            msg = create_task_res(
-                producer_node_id, anonymous, ancestry, has_legacy_client_message=True
-            )
+            msg = create_task_res(producer_node_id, anonymous, ancestry)
             val_errors = validate_task_ins_or_res(msg)
             self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry))
 
-    def test_task_ins_secure_aggregation(self) -> None:
-        """Test is_valid task_ins for Secure Aggregation."""
-        # Prepare
-        # (has_legacy_server_message, has_sa)
-        valid_ins = [(True, True), (False, True)]
-        invalid_ins = [(False, False)]
-
-        # Execute & Assert
-        for has_legacy_server_message, has_sa in valid_ins:
-            msg = create_task_ins(1, False, has_legacy_server_message, has_sa)
-            val_errors = validate_task_ins_or_res(msg)
-            self.assertFalse(val_errors)
-
-        for has_legacy_server_message, has_sa in invalid_ins:
-            msg = create_task_ins(1, False, has_legacy_server_message, has_sa)
-            val_errors = validate_task_ins_or_res(msg)
-            self.assertTrue(val_errors)
-
-    def test_task_res_secure_aggregation(self) -> None:
-        """Test is_valid task_res for Secure Aggregation."""
-        # Prepare
-        # (has_legacy_server_message, has_sa)
-        valid_res = [(True, True), (False, True)]
-        invalid_res = [(False, False)]
-
-        # Execute & Assert
-        for has_legacy_client_message, has_sa in valid_res:
-            msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa)
-            val_errors = validate_task_ins_or_res(msg)
-            self.assertFalse(val_errors)
-
-        for has_legacy_client_message, has_sa in invalid_res:
-            msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa)
-            val_errors = validate_task_ins_or_res(msg)
-            self.assertTrue(val_errors)
-
 
 def create_task_ins(
     consumer_node_id: int,
     anonymous: bool,
-    has_legacy_server_message: bool = False,
-    has_sa: bool = False,
     delivered_at: str = "",
 ) -> TaskIns:
     """Create a TaskIns for testing."""
@@ -148,12 +94,8 @@ def create_task_ins(
             delivered_at=delivered_at,
             producer=Node(node_id=0, anonymous=True),
             consumer=consumer,
-            legacy_server_message=ServerMessage(
-                reconnect_ins=ServerMessage.ReconnectIns()
-            )
-            if has_legacy_server_message
-            else None,
-            sa=SecureAggregation(named_values={}) if has_sa else None,
+            task_type="mock",
+            recordset=RecordSet(parameters={}, metrics={}, configs={}),
         ),
     )
     return task
@@ -163,8 +105,6 @@ def create_task_res(
     producer_node_id: int,
     anonymous: bool,
     ancestry: List[str],
-    has_legacy_client_message: bool = False,
-    has_sa: bool = False,
 ) -> TaskRes:
     """Create a TaskRes for testing."""
     task_res = TaskRes(
@@ -175,12 +115,8 @@ def create_task_res(
             producer=Node(node_id=producer_node_id, anonymous=anonymous),
             consumer=Node(node_id=0, anonymous=True),
             ancestry=ancestry,
-            legacy_client_message=ClientMessage(
-                disconnect_res=ClientMessage.DisconnectRes()
-            )
-            if has_legacy_client_message
-            else None,
-            sa=SecureAggregation(named_values={}) if has_sa else None,
+            task_type="mock",
+            recordset=RecordSet(parameters={}, metrics={}, configs={}),
         ),
    )
     return task_res
diff --git a/src/py/flwr/server/workflow/__init__.py b/src/py/flwr/server/workflow/__init__.py
new file mode 100644
index 000000000000..31dee89a185d
--- /dev/null
+++ b/src/py/flwr/server/workflow/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflows."""
+
+
+from .default_workflows import DefaultWorkflow
+from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
+
+__all__ = [
+    "DefaultWorkflow",
+    "SecAggPlusWorkflow",
+    "SecAggWorkflow",
+]
diff --git a/src/py/flwr/server/workflow/constant.py b/src/py/flwr/server/workflow/constant.py
new file mode 100644
index 000000000000..068e05b27e12
--- /dev/null
+++ b/src/py/flwr/server/workflow/constant.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for default workflows."""
+
+
+from __future__ import annotations
+
+MAIN_CONFIGS_RECORD = "config"
+MAIN_PARAMS_RECORD = "parameters"
+
+
+class Key:
+    """Constants for default workflows."""
+
+    CURRENT_ROUND = "current_round"
+    START_TIME = "start_time"
+
+    def __new__(cls) -> Key:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py
new file mode 100644
index 000000000000..fad85d8eecf8
--- /dev/null
+++ b/src/py/flwr/server/workflow/default_workflows.py
@@ -0,0 +1,348 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Legacy default workflows."""
+
+
+import io
+import timeit
+from logging import INFO
+from typing import Optional, cast
+
+import flwr.common.recordset_compat as compat
+from flwr.common import ConfigsRecord, Context, GetParametersIns, log
+from flwr.common.constant import MessageType, MessageTypeLegacy
+
+from ..compat.app_utils import start_update_client_manager_thread
+from ..compat.legacy_context import LegacyContext
+from ..driver import Driver
+from ..typing import Workflow
+from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
+
+
+class DefaultWorkflow:
+    """Default workflow in Flower."""
+
+    def __init__(
+        self,
+        fit_workflow: Optional[Workflow] = None,
+        evaluate_workflow: Optional[Workflow] = None,
+    ) -> None:
+        if fit_workflow is None:
+            fit_workflow = default_fit_workflow
+        if evaluate_workflow is None:
+            evaluate_workflow = default_evaluate_workflow
+        self.fit_workflow: Workflow = fit_workflow
+        self.evaluate_workflow: Workflow = evaluate_workflow
+
+    def __call__(self, driver: Driver, context: Context) -> None:
+        """Execute the workflow."""
+        if not isinstance(context, LegacyContext):
+            raise TypeError(
+                f"Expected a LegacyContext, but got {type(context).__name__}."
+            )
+
+        # Start the thread updating nodes
+        thread, f_stop = start_update_client_manager_thread(
+            driver, context.client_manager
+        )
+
+        # Initialize parameters
+        log(INFO, "[INIT]")
+        default_init_params_workflow(driver, context)
+
+        # Run federated learning for num_rounds
+        start_time = timeit.default_timer()
+        cfg = ConfigsRecord()
+        cfg[Key.START_TIME] = start_time
+        context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
+
+        for current_round in range(1, context.config.num_rounds + 1):
+            log(INFO, "")
+            log(INFO, "[ROUND %s]", current_round)
+            cfg[Key.CURRENT_ROUND] = current_round
+
+            # Fit round
+            self.fit_workflow(driver, context)
+
+            # Centralized evaluation
+            default_centralized_evaluation_workflow(driver, context)
+
+            # Evaluate round
+            self.evaluate_workflow(driver, context)
+
+        # Bookkeeping and log results
+        end_time = timeit.default_timer()
+        elapsed = end_time - start_time
+        hist = context.history
+        log(INFO, "")
+        log(INFO, "[SUMMARY]")
+        log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
+        for idx, line in enumerate(io.StringIO(str(hist))):
+            if idx == 0:
+                log(INFO, "%s", line.strip("\n"))
+            else:
+                log(INFO, "\t%s", line.strip("\n"))
+        log(INFO, "")
+
+        # Terminate the thread
+        f_stop.set()
+        del driver
+        thread.join()
+
+
+def default_init_params_workflow(driver: Driver, context: Context) -> None:
+    """Execute the default workflow for parameters initialization."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    parameters = context.strategy.initialize_parameters(
+        client_manager=context.client_manager
+    )
+    if parameters is not None:
+        log(INFO, "Using initial global parameters provided by strategy")
+        paramsrecord = compat.parameters_to_parametersrecord(
+            parameters, keep_input=True
+        )
+    else:
+        # Get initial parameters from one of the clients
+        log(INFO, "Requesting initial parameters from one random client")
+        random_client = context.client_manager.sample(1)[0]
+        # Send GetParametersIns and get the response
+        content = compat.getparametersins_to_recordset(GetParametersIns({}))
+        messages = driver.send_and_receive(
+            [
+                driver.create_message(
+                    content=content,
+                    message_type=MessageTypeLegacy.GET_PARAMETERS,
+                    dst_node_id=random_client.node_id,
+                    group_id="",
+                    ttl="",
+                )
+            ]
+        )
+        log(INFO, "Received initial parameters from one random client")
+        msg = list(messages)[0]
+        paramsrecord = next(iter(msg.content.parameters_records.values()))
+
+    context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+
+    # Evaluate initial parameters
+    log(INFO, "Evaluating initial global parameters")
+    parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
+    res = context.strategy.evaluate(0, parameters=parameters)
+    if res is not None:
+        log(
+            INFO,
+            "initial parameters (loss, other metrics): %s, %s",
+            res[0],
+            res[1],
+        )
+        context.history.add_loss_centralized(server_round=0, loss=res[0])
+        context.history.add_metrics_centralized(server_round=0, metrics=res[1])
+
+
+def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
+    """Execute the default workflow for centralized evaluation."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Retrieve current_round and start_time from the context
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    start_time = cast(float, cfg[Key.START_TIME])
+
+    # Centralized evaluation
+    parameters = compat.parametersrecord_to_parameters(
+        record=context.state.parameters_records[MAIN_PARAMS_RECORD],
+        keep_input=True,
+    )
+    res_cen = context.strategy.evaluate(current_round, parameters=parameters)
+    if res_cen is not None:
+        loss_cen, metrics_cen = res_cen
+        log(
+            INFO,
+            "fit progress: (%s, %s, %s, %s)",
+            current_round,
+            loss_cen,
+            metrics_cen,
+            timeit.default_timer() - start_time,
+        )
+        context.history.add_loss_centralized(server_round=current_round, loss=loss_cen)
+        context.history.add_metrics_centralized(
+            server_round=current_round, metrics=metrics_cen
+        )
+
+
+def default_fit_workflow(  # pylint: disable=R0914
+    driver: Driver, context: Context
+) -> None:
+    """Execute the default workflow for a single fit round."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Get current_round and parameters
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
+    parameters = compat.parametersrecord_to_parameters(
+        parametersrecord, keep_input=True
+    )
+
+    # Get clients and their respective instructions from strategy
+    client_instructions = context.strategy.configure_fit(
+        server_round=current_round,
+        parameters=parameters,
+        client_manager=context.client_manager,
+    )
+
+    if not client_instructions:
+        log(INFO, "configure_fit: no clients selected, cancel")
+        return
+    log(
+        INFO,
+        "configure_fit: strategy sampled %s clients (out of %s)",
+        len(client_instructions),
+        context.client_manager.num_available(),
+    )
+
+    # Build dictionary mapping node_id to ClientProxy
+    node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
+
+    # Build out messages
+    out_messages = [
+        driver.create_message(
+            content=compat.fitins_to_recordset(fitins, True),
+            message_type=MessageType.TRAIN,
+            dst_node_id=proxy.node_id,
+            group_id="",
+            ttl="",
+        )
+        for proxy, fitins in client_instructions
+    ]
+
+    # Send instructions to clients and
+    # collect `fit` results from all clients participating in this round
+    messages = list(driver.send_and_receive(out_messages))
+    del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
+
+    # No exception/failure handling currently
+    log(
+        INFO,
+        "aggregate_fit: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
+    )
+
+    # Aggregate training results
+    results = [
+        (
+            node_id_to_proxy[msg.metadata.src_node_id],
+            compat.recordset_to_fitres(msg.content, False),
+        )
+        for msg in messages
+    ]
+    aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
+    parameters_aggregated, metrics_aggregated = aggregated_result
+
+    # Update the parameters and write history
+    if parameters_aggregated:
+        paramsrecord = compat.parameters_to_parametersrecord(
+            parameters_aggregated, True
+        )
+        context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+        context.history.add_metrics_distributed_fit(
+            server_round=current_round, metrics=metrics_aggregated
+        )
+
+
+def default_evaluate_workflow(driver: Driver, context: Context) -> None:
+    """Execute the default workflow for a single evaluate round."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Get current_round and parameters
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
+    parameters = compat.parametersrecord_to_parameters(
+        parametersrecord, keep_input=True
+    )
+
+    # Get clients and their respective instructions from strategy
+    client_instructions = context.strategy.configure_evaluate(
+        server_round=current_round,
+        parameters=parameters,
+        client_manager=context.client_manager,
+    )
+    if not client_instructions:
+        log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
+        return
+    log(
+        INFO,
+        "configure_evaluate: strategy sampled %s clients (out of %s)",
+        len(client_instructions),
+        context.client_manager.num_available(),
+    )
+
+    # Build dictionary mapping node_id to ClientProxy
+    node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
+
+    # Build out messages
+    out_messages = [
+        driver.create_message(
+            content=compat.evaluateins_to_recordset(evalins, True),
+            message_type=MessageType.EVALUATE,
+            dst_node_id=proxy.node_id,
+            group_id="",
+            ttl="",
+        )
+        for proxy, evalins in client_instructions
+    ]
+
+    # Send instructions to clients and
+    # collect `evaluate` results from all clients participating in this round
+    messages = list(driver.send_and_receive(out_messages))
+    del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
+
+    # No exception/failure handling currently
+    log(
+        INFO,
+        "aggregate_evaluate: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
+    )
+
+    # Aggregate the evaluation results
+    results = [
+        (
+            node_id_to_proxy[msg.metadata.src_node_id],
+            compat.recordset_to_evaluateres(msg.content),
+        )
+        for msg in messages
+    ]
+    aggregated_result = context.strategy.aggregate_evaluate(current_round, results, [])
+
+    loss_aggregated, metrics_aggregated = aggregated_result
+
+    # Write history
+    if loss_aggregated is not None:
+        context.history.add_loss_distributed(
+            server_round=current_round, loss=loss_aggregated
+        )
+        context.history.add_metrics_distributed(
+            server_round=current_round, metrics=metrics_aggregated
+        )
diff --git a/src/py/flwr/client/secure_aggregation/__init__.py b/src/py/flwr/server/workflow/secure_aggregation/__init__.py
similarity index 72%
rename from src/py/flwr/client/secure_aggregation/__init__.py
rename to src/py/flwr/server/workflow/secure_aggregation/__init__.py
index 37c816a390de..25e2a32da334 100644
--- a/src/py/flwr/client/secure_aggregation/__init__.py
+++ b/src/py/flwr/server/workflow/secure_aggregation/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Secure Aggregation handlers."""
+"""Secure Aggregation workflows."""
 
 
-from .handler import SecureAggregationHandler
-from .secaggplus_handler import SecAggPlusHandler
+from .secagg_workflow import SecAggWorkflow
+from .secaggplus_workflow import SecAggPlusWorkflow
 
 __all__ = [
-    "SecAggPlusHandler",
-    "SecureAggregationHandler",
+    "SecAggPlusWorkflow",
+    "SecAggWorkflow",
 ]
diff --git a/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py
new file mode 100644
index 000000000000..f56423e4a0d0
--- /dev/null
+++ b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py
@@ -0,0 +1,112 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflow for the SecAgg protocol."""
+
+
+from typing import Optional, Union
+
+from .secaggplus_workflow import SecAggPlusWorkflow
+
+
+class SecAggWorkflow(SecAggPlusWorkflow):
+    """The workflow for the SecAgg protocol.
+
+    The SecAgg protocol ensures the secure summation of integer vectors owned by
+    multiple parties, without accessing any individual integer vector. This workflow
+    allows the server to compute the weighted average of model parameters across all
+    clients, ensuring individual contributions remain private. This is achieved by
+    clients sending both a weighting factor and a weighted version of the locally
+    updated parameters, both of which are masked for privacy. Specifically, each
+    client uploads "[w, w * params]" with masks, where the weighting factor 'w' is the
+    number of examples ('num_examples') and 'params' represents the model parameters
+    ('parameters') from the client's `FitRes`. The server then aggregates these
+    contributions to compute the weighted average of model parameters.
+
+    The protocol involves four main stages:
+    - 'setup': Send SecAgg configuration to clients and collect their public keys.
+    - 'share keys': Broadcast public keys among clients and collect encrypted secret
+      key shares.
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
+      and collect masked model parameters.
+    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
+
+    Only the aggregated model parameters are exposed and passed to
+    `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+    Parameters
+    ----------
+    reconstruction_threshold : Union[int, float]
+        The minimum number of shares required to reconstruct a client's private key,
+        or, if specified as a float, it represents the proportion of the total number
+        of shares needed for reconstruction. This threshold ensures privacy by allowing
+        for the recovery of contributions from dropped clients during aggregation,
+        without compromising individual client data.
+ max_weight : Optional[float] (default: 1000.0) + The maximum value of the weight that can be assigned to any single client's + update during the weighted average calculation on the server side, e.g., in the + FedAvg algorithm. + clipping_range : float, optional (default: 8.0) + The range within which model parameters are clipped before quantization. + This parameter ensures each model parameter is bounded within + [-clipping_range, clipping_range], facilitating quantization. + quantization_range : int, optional (default: 4194304, this equals 2**22) + The size of the range into which floating-point model parameters are quantized, + mapping each parameter to an integer in [0, quantization_range-1]. This + facilitates cryptographic operations on the model updates. + modulus_range : int, optional (default: 4294967296, this equals 2**32) + The range of values from which random mask entries are uniformly sampled + ([0, modulus_range-1]). `modulus_range` must be less than 4294967296. + Please use 2**n values for `modulus_range` to prevent overflow issues. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the workflow will wait for + replies for this duration each time. If `None`, there is no time limit and + the workflow will wait until replies for all messages are received. + + Notes + ----- + - Each client's private key is split into N shares under the SecAgg protocol, where + N is the number of selected clients. + - Generally, higher `reconstruction_threshold` means better privacy guarantees but + less tolerance to dropouts. + - Too large `max_weight` may compromise the precision of the quantization. + - `modulus_range` must be 2**n and larger than `quantization_range`. + - When `reconstruction_threshold` is a float, it is interpreted as the proportion of + the number of all selected clients needed for the reconstruction of a private key. + This feature enables flexibility in setting the security threshold relative to the + number of selected clients. + - `reconstruction_threshold`, and the quantization parameters + (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in + balancing privacy, robustness, and efficiency within the SecAgg protocol. + """ + + def __init__( # pylint: disable=R0913 + self, + reconstruction_threshold: Union[int, float], + *, + max_weight: float = 1000.0, + clipping_range: float = 8.0, + quantization_range: int = 4194304, + modulus_range: int = 4294967296, + timeout: Optional[float] = None, + ) -> None: + super().__init__( + num_shares=1.0, + reconstruction_threshold=reconstruction_threshold, + max_weight=max_weight, + clipping_range=clipping_range, + quantization_range=quantization_range, + modulus_range=modulus_range, + timeout=timeout, + ) diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py new file mode 100644 index 000000000000..559dc1cf8739 --- /dev/null +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -0,0 +1,676 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Workflow for the SecAgg+ protocol.""" + + +import random +from dataclasses import dataclass, field +from logging import DEBUG, ERROR, INFO, WARN +from typing import Dict, List, Optional, Set, Union, cast + +import flwr.common.recordset_compat as compat +from flwr.common import ( + Code, + ConfigsRecord, + Context, + FitRes, + Message, + MessageType, + NDArrays, + RecordSet, + Status, + bytes_to_ndarray, + log, + ndarrays_to_parameters, +) +from flwr.common.secure_aggregation.crypto.shamir import combine_shares +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_private_key, + bytes_to_public_key, + generate_shared_key, +) +from flwr.common.secure_aggregation.ndarrays_arithmetic import ( + factor_extract, + get_parameters_shape, + parameters_addition, + parameters_mod, + parameters_subtraction, +) +from flwr.common.secure_aggregation.quantization import dequantize +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + Key, + Stage, +) +from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen +from flwr.server.compat.driver_client_proxy import DriverClientProxy +from flwr.server.compat.legacy_context import LegacyContext +from flwr.server.driver import Driver + +from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD +from ..constant import Key as WorkflowKey + + +@dataclass +class WorkflowState: # pylint: disable=R0902 + """The state of the SecAgg+ protocol.""" + + nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict) + sampled_node_ids: Set[int] = field(default_factory=set) + active_node_ids: Set[int] = field(default_factory=set) + num_shares: int = 0 + threshold: int = 0 + clipping_range: float = 0.0 + quantization_range: int = 0 + mod_range: int = 0 + max_weight: float = 0.0 + nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict) + nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict) + forward_srcs: Dict[int, List[int]] = field(default_factory=dict) + forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict) + aggregate_ndarrays: NDArrays = field(default_factory=list) + + +class SecAggPlusWorkflow: + """The workflow for the SecAgg+ protocol. + + The SecAgg+ protocol ensures the secure summation of integer vectors owned by + multiple parties, without accessing any individual integer vector. This workflow + allows the server to compute the weighted average of model parameters across all + clients, ensuring individual contributions remain private. This is achieved by + clients sending both, a weighting factor and a weighted version of the locally + updated parameters, both of which are masked for privacy. Specifically, each + client uploads "[w, w * params]" with masks, where weighting factor 'w' is the + number of examples ('num_examples') and 'params' represents the model parameters + ('parameters') from the client's `FitRes`. The server then aggregates these + contributions to compute the weighted average of model parameters. 
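+
+    (As a concrete illustration of the masking: pairwise masks cancel in the
+    summation, e.g. masked uploads u1 = v1 + m and u2 = v2 - m satisfy
+    u1 + u2 == v1 + v2, so the sum is recovered without either individual
+    vector ever being visible to the server.)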
+
+    The protocol involves four main stages:
+    - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
+    - 'share keys': Broadcast public keys among clients and collect encrypted secret
+      key shares.
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
+      and collect masked model parameters.
+    - 'unmask': Collect secret key shares to decrypt and aggregate the model
+      parameters.
+
+    Only the aggregated model parameters are exposed and passed to
+    `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+    Parameters
+    ----------
+    num_shares : Union[int, float]
+        The number of shares into which each client's private key is split under
+        the SecAgg+ protocol. If specified as a float, it represents the proportion
+        of all selected clients, and the number of shares will be set dynamically at
+        run time. A private key can be reconstructed from these shares, allowing
+        for the secure aggregation of model updates. Each client sends one share to
+        each of its neighbors while retaining one.
+    reconstruction_threshold : Union[int, float]
+        The minimum number of shares required to reconstruct a client's private key,
+        or, if specified as a float, it represents the proportion of the total number
+        of shares needed for reconstruction. This threshold ensures privacy by
+        allowing for the recovery of contributions from dropped clients during
+        aggregation, without compromising individual client data.
+    max_weight : float, optional (default: 1000.0)
+        The maximum value of the weight that can be assigned to any single client's
+        update during the weighted average calculation on the server side, e.g., in
+        the FedAvg algorithm.
+    clipping_range : float, optional (default: 8.0)
+        The range within which model parameters are clipped before quantization.
+        This parameter ensures each model parameter is bounded within
+        [-clipping_range, clipping_range], facilitating quantization.
+    quantization_range : int, optional (default: 4194304, this equals 2**22)
+        The size of the range into which floating-point model parameters are
+        quantized, mapping each parameter to an integer in [0, quantization_range-1].
+        This facilitates cryptographic operations on the model updates.
+    modulus_range : int, optional (default: 4294967296, this equals 2**32)
+        The range of values from which random mask entries are uniformly sampled
+        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
+        Please use 2**n values for `modulus_range` to prevent overflow issues.
+    timeout : Optional[float] (default: None)
+        The timeout duration in seconds. If specified, the workflow will wait for
+        replies for this duration each time. If `None`, there is no time limit and
+        the workflow will wait until replies for all messages are received.
+
+    Notes
+    -----
+    - Generally, a higher `num_shares` means more robustness to dropouts at the cost
+      of increased computation; a higher `reconstruction_threshold` means better
+      privacy guarantees but less tolerance to dropouts.
+    - A `max_weight` that is too large may compromise the precision of the
+      quantization.
+    - `modulus_range` must be 2**n and larger than `quantization_range`.
+    - When `num_shares` is a float, it is interpreted as the proportion of all
+      selected clients, and hence the number of shares will be determined at run
+      time. This allows for dynamic adjustment based on the total number of
+      participating clients.
+    - Similarly, when `reconstruction_threshold` is a float, it is interpreted as the
+      proportion of the number of shares needed for the reconstruction of a private
+      key. This feature enables flexibility in setting the security threshold
+      relative to the number of distributed shares.
+    - `num_shares`, `reconstruction_threshold`, and the quantization parameters
+      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles
+      in balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        num_shares: Union[int, float],
+        reconstruction_threshold: Union[int, float],
+        *,
+        max_weight: float = 1000.0,
+        clipping_range: float = 8.0,
+        quantization_range: int = 4194304,
+        modulus_range: int = 4294967296,
+        timeout: Optional[float] = None,
+    ) -> None:
+        self.num_shares = num_shares
+        self.reconstruction_threshold = reconstruction_threshold
+        self.max_weight = max_weight
+        self.clipping_range = clipping_range
+        self.quantization_range = quantization_range
+        self.modulus_range = modulus_range
+        self.timeout = timeout
+
+        self._check_init_params()
+
+    def __call__(self, driver: Driver, context: Context) -> None:
+        """Run the SecAgg+ protocol."""
+        if not isinstance(context, LegacyContext):
+            raise TypeError(
+                f"Expected a LegacyContext, but got {type(context).__name__}."
+            )
+        state = WorkflowState()
+
+        steps = (
+            self.setup_stage,
+            self.share_keys_stage,
+            self.collect_masked_vectors_stage,
+            self.unmask_stage,
+        )
+        log(INFO, "Secure aggregation commencing.")
+        for step in steps:
+            if not step(driver, context, state):
+                log(INFO, "Secure aggregation halted.")
+                return
+        log(INFO, "Secure aggregation completed.")
+
+    def _check_init_params(self) -> None:  # pylint: disable=R0912
+        # Check `num_shares`
+        if not isinstance(self.num_shares, (int, float)):
+            raise TypeError("`num_shares` must be of type int or float.")
+        if isinstance(self.num_shares, int):
+            if self.num_shares == 1:
+                self.num_shares = 1.0
+            elif self.num_shares <= 2:
+                raise ValueError("`num_shares` as an integer must be greater than 2.")
+            elif self.num_shares > self.modulus_range / self.quantization_range:
+                log(
+                    WARN,
+                    "A `num_shares` larger than `modulus_range / quantization_range` "
+                    "will potentially cause overflow when computing the aggregated "
+                    "model parameters.",
+                )
+        elif self.num_shares <= 0:
+            raise ValueError("`num_shares` as a float must be greater than 0.")
+
+        # Check `reconstruction_threshold`
+        if not isinstance(self.reconstruction_threshold, (int, float)):
+            raise TypeError("`reconstruction_threshold` must be of type int or float.")
+        if isinstance(self.reconstruction_threshold, int):
+            if self.reconstruction_threshold == 1:
+                self.reconstruction_threshold = 1.0
+            elif isinstance(self.num_shares, int):
+                if self.reconstruction_threshold >= self.num_shares:
+                    raise ValueError(
+                        "`reconstruction_threshold` must be less than `num_shares`."
+                    )
+        else:
+            if not 0 < self.reconstruction_threshold <= 1:
+                raise ValueError(
+                    "If `reconstruction_threshold` is a float, "
+                    "it must be greater than 0 and less than or equal to 1."
+                )
+
+        # Check `max_weight`
+        if self.max_weight <= 0:
+            raise ValueError("`max_weight` must be greater than 0.")
+
+        # Check `quantization_range`
+        if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
+            raise ValueError(
+                "`quantization_range` must be an integer and greater than 0."
+            )
+
+        # Check `modulus_range`
+        if (
+            not isinstance(self.modulus_range, int)
+            or self.modulus_range <= self.quantization_range
+        ):
+            raise ValueError(
+                "`modulus_range` must be an integer and "
+                "greater than `quantization_range`."
+            )
+        if bin(self.modulus_range).count("1") != 1:
+            raise ValueError("`modulus_range` must be a power of 2.")
+
+    def _check_threshold(self, state: WorkflowState) -> bool:
+        for node_id in state.sampled_node_ids:
+            active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
+            if len(active_neighbors) < state.threshold:
+                log(ERROR, "Insufficient available nodes.")
+                return False
+        return True
+
+    def setup_stage(  # pylint: disable=R0912, R0914, R0915
+        self, driver: Driver, context: LegacyContext, state: WorkflowState
+    ) -> bool:
+        """Execute the 'setup' stage."""
+        # Obtain fit instructions
+        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+        parameters = compat.parametersrecord_to_parameters(
+            context.state.parameters_records[MAIN_PARAMS_RECORD],
+            keep_input=True,
+        )
+        proxy_fitins_lst = context.strategy.configure_fit(
+            current_round, parameters, context.client_manager
+        )
+        if not proxy_fitins_lst:
+            log(INFO, "configure_fit: no clients selected, cancel")
+            return False
+        log(
+            INFO,
+            "configure_fit: strategy sampled %s clients (out of %s)",
+            len(proxy_fitins_lst),
+            context.client_manager.num_available(),
+        )
+
+        state.nid_to_fitins = {
+            proxy.node_id: compat.fitins_to_recordset(fitins, False)
+            for proxy, fitins in proxy_fitins_lst
+        }
+
+        # Protocol config
+        sampled_node_ids = list(state.nid_to_fitins.keys())
+        num_samples = len(sampled_node_ids)
+        if num_samples < 2:
+            log(ERROR, "The number of samples should be greater than 1.")
+            return False
+        if isinstance(self.num_shares, float):
+            state.num_shares = round(self.num_shares * num_samples)
+            # If even
+            if state.num_shares < num_samples and state.num_shares & 1 == 0:
+                state.num_shares += 1
+            # If too small
+            if state.num_shares <= 2:
+                state.num_shares = num_samples
+        else:
+            state.num_shares = self.num_shares
+        if isinstance(self.reconstruction_threshold, float):
+            state.threshold = round(self.reconstruction_threshold * state.num_shares)
+            # Avoid too small threshold
+            state.threshold = max(state.threshold, 2)
+        else:
+            state.threshold = self.reconstruction_threshold
+        state.active_node_ids = set(sampled_node_ids)
+        state.clipping_range = self.clipping_range
+        state.quantization_range = self.quantization_range
+        state.mod_range = self.modulus_range
+        state.max_weight = self.max_weight
+        sa_params_dict = {
+            Key.STAGE: Stage.SETUP,
+            Key.SAMPLE_NUMBER: num_samples,
+            Key.SHARE_NUMBER: state.num_shares,
+            Key.THRESHOLD: state.threshold,
+            Key.CLIPPING_RANGE: state.clipping_range,
+            Key.TARGET_RANGE: state.quantization_range,
+            Key.MOD_RANGE: state.mod_range,
+            Key.MAX_WEIGHT: state.max_weight,
+        }
+
+        # The number of shares should preferably be odd in the SecAgg+ protocol.
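+        # (The neighbourhood built below spans `num_shares >> 1` nodes on each
+        # side of a client plus the client itself, which is why an even share
+        # count is bumped to the next odd number.)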
+        if num_samples != state.num_shares and state.num_shares & 1 == 0:
+            log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
+            state.num_shares += 1
+
+        # Shuffle node IDs
+        random.shuffle(sampled_node_ids)
+        # Build neighbour relations (node ID -> node IDs of neighbours)
+        half_share = state.num_shares >> 1
+        state.nid_to_neighbours = {
+            nid: {
+                sampled_node_ids[(idx + offset) % num_samples]
+                for offset in range(-half_share, half_share + 1)
+            }
+            for idx, nid in enumerate(sampled_node_ids)
+        }
+
+        state.sampled_node_ids = state.active_node_ids
+
+        # Send setup configuration to clients
+        cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
+        content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+
+        def make(nid: int) -> Message:
+            return driver.create_message(
+                content=content,
+                message_type=MessageType.TRAIN,
+                dst_node_id=nid,
+                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                ttl="",
+            )
+
+        log(
+            DEBUG,
+            "[Stage 0] Sending configurations to %s clients.",
+            len(state.active_node_ids),
+        )
+        msgs = driver.send_and_receive(
+            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+        )
+        state.active_node_ids = {
+            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+        }
+        log(
+            DEBUG,
+            "[Stage 0] Received public keys from %s clients.",
+            len(state.active_node_ids),
+        )
+
+        for msg in msgs:
+            if msg.has_error():
+                continue
+            key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            node_id = msg.metadata.src_node_id
+            pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
+            state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
+
+        return self._check_threshold(state)
+
+    def share_keys_stage(  # pylint: disable=R0914
+        self, driver: Driver, context: LegacyContext, state: WorkflowState
+    ) -> bool:
+        """Execute the 'share keys' stage."""
+        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+        def make(nid: int) -> Message:
+            neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
+            cfgs_record = ConfigsRecord(
+                {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
+            )
+            cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
+            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+            return driver.create_message(
+                content=content,
+                message_type=MessageType.TRAIN,
+                dst_node_id=nid,
+                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                ttl="",
+            )
+
+        # Broadcast public keys to clients and receive secret key shares
+        log(
+            DEBUG,
+            "[Stage 1] Forwarding public keys to %s clients.",
+            len(state.active_node_ids),
+        )
+        msgs = driver.send_and_receive(
+            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+        )
+        state.active_node_ids = {
+            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+        }
+        log(
+            DEBUG,
+            "[Stage 1] Received encrypted key shares from %s clients.",
+            len(state.active_node_ids),
+        )
+
+        # Build forward packet list dictionary
+        srcs: List[int] = []
+        dsts: List[int] = []
+        ciphertexts: List[bytes] = []
+        fwd_ciphertexts: Dict[int, List[bytes]] = {
+            nid: [] for nid in state.active_node_ids
+        }  # dest node ID -> list of ciphertexts
+        fwd_srcs: Dict[int, List[int]] = {
+            nid: [] for nid in state.active_node_ids
+        }  # dest node ID -> list of src node IDs
+        for msg in msgs:
+            node_id = msg.metadata.src_node_id
+            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
+            ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
+            srcs += [node_id] * len(dst_lst)
+            dsts += dst_lst
+            ciphertexts += ctxt_lst
+
+        for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
+            if dst in fwd_ciphertexts:
+                fwd_ciphertexts[dst].append(ciphertext)
+                fwd_srcs[dst].append(src)
+
+        state.forward_srcs = fwd_srcs
+        state.forward_ciphertexts = fwd_ciphertexts
+
+        return self._check_threshold(state)
+
+    def collect_masked_vectors_stage(
+        self, driver: Driver, context: LegacyContext, state: WorkflowState
+    ) -> bool:
+        """Execute the 'collect masked vectors' stage."""
+        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+        # Send secret key shares to clients (plus FitIns) and collect masked vectors
+        def make(nid: int) -> Message:
+            cfgs_dict = {
+                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
+                Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
+                Key.SOURCE_LIST: state.forward_srcs[nid],
+            }
+            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+            content = state.nid_to_fitins[nid]
+            content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
+            return driver.create_message(
+                content=content,
+                message_type=MessageType.TRAIN,
+                dst_node_id=nid,
+                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                ttl="",
+            )
+
+        log(
+            DEBUG,
+            "[Stage 2] Forwarding encrypted key shares to %s clients.",
+            len(state.active_node_ids),
+        )
+        msgs = driver.send_and_receive(
+            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+        )
+        state.active_node_ids = {
+            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+        }
+        log(
+            DEBUG,
+            "[Stage 2] Received masked vectors from %s clients.",
+            len(state.active_node_ids),
+        )
+
+        # Clear cache
+        del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
+
+        # Sum collected masked vectors and compute active/dead node IDs
+        masked_vector = None
+        for msg in msgs:
+            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
+            client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
+            if masked_vector is None:
+                masked_vector = client_masked_vec
+            else:
+                masked_vector = parameters_addition(masked_vector, client_masked_vec)
+        if masked_vector is not None:
+            masked_vector = parameters_mod(masked_vector, state.mod_range)
+            state.aggregate_ndarrays = masked_vector
+
+        return self._check_threshold(state)
+
+    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
+        self, driver: Driver, context: LegacyContext, state: WorkflowState
+    ) -> bool:
+        """Execute the 'unmask' stage."""
+        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+
+        # Construct active node IDs and dead node IDs
+        active_nids = state.active_node_ids
+        dead_nids = state.sampled_node_ids - active_nids
+
+        # Send node IDs of active and dead clients and collect key shares from clients
+        def make(nid: int) -> Message:
+            neighbours = state.nid_to_neighbours[nid]
+            cfgs_dict = {
+                Key.STAGE: Stage.UNMASK,
+                Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
+                Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
+            }
+            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+            return driver.create_message(
+                content=content,
+                message_type=MessageType.TRAIN,
+                dst_node_id=nid,
+                group_id=str(current_round),
+                ttl="",
+            )
+
+        log(
+            DEBUG,
+            "[Stage 3] Requesting key shares from %s clients to remove masks.",
+            len(state.active_node_ids),
+        )
+        msgs = driver.send_and_receive(
+            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+        )
+        state.active_node_ids = {
+            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+        }
+        log(
+            DEBUG,
+            "[Stage 3] Received key shares from %s clients.",
+            len(state.active_node_ids),
+        )
+
+        # Build collected shares dict
+        collected_shares_dict: Dict[int, List[bytes]] = {}
+        for nid in state.sampled_node_ids:
+            collected_shares_dict[nid] = []
+        for msg in msgs:
+            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
+            shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
+            for owner_nid, share in zip(nids, shares):
+                collected_shares_dict[owner_nid].append(share)
+
+        # Remove masks for every active client after collect_masked_vectors stage
+        masked_vector = state.aggregate_ndarrays
+        del state.aggregate_ndarrays
+        for nid, share_list in collected_shares_dict.items():
+            if len(share_list) < state.threshold:
+                log(
+                    ERROR, "Not enough shares to recover secret in the 'unmask' stage"
+                )
+                return False
+            secret = combine_shares(share_list)
+            if nid in active_nids:
+                # The seed for PRG is the private mask seed of an active client.
+                private_mask = pseudo_rand_gen(
+                    secret, state.mod_range, get_parameters_shape(masked_vector)
+                )
+                masked_vector = parameters_subtraction(masked_vector, private_mask)
+            else:
+                # The seed for PRG is the secret key 1 of a dropped client.
+                neighbours = state.nid_to_neighbours[nid]
+                neighbours.remove(nid)
+
+                for neighbor_nid in neighbours:
+                    shared_key = generate_shared_key(
+                        bytes_to_private_key(secret),
+                        bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
+                    )
+                    pairwise_mask = pseudo_rand_gen(
+                        shared_key, state.mod_range, get_parameters_shape(masked_vector)
+                    )
+                    if nid > neighbor_nid:
+                        masked_vector = parameters_addition(
+                            masked_vector, pairwise_mask
+                        )
+                    else:
+                        masked_vector = parameters_subtraction(
+                            masked_vector, pairwise_mask
+                        )
+        recon_parameters = parameters_mod(masked_vector, state.mod_range)
+        q_total_ratio, recon_parameters = factor_extract(recon_parameters)
+        inv_dq_total_ratio = state.quantization_range / q_total_ratio
+        # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
+        aggregated_vector = dequantize(
+            recon_parameters,
+            state.clipping_range,
+            state.quantization_range,
+        )
+        offset = -(len(active_nids) - 1) * state.clipping_range
+        for vec in aggregated_vector:
+            vec += offset
+            vec *= inv_dq_total_ratio
+        state.aggregate_ndarrays = aggregated_vector
+
+        # No exception/failure handling currently
+        log(
+            INFO,
+            "aggregate_fit: received %s results and %s failures",
+            1,
+            0,
+        )
+
+        final_fitres = FitRes(
+            status=Status(code=Code.OK, message=""),
+            parameters=ndarrays_to_parameters(aggregated_vector),
+            num_examples=round(state.max_weight / inv_dq_total_ratio),
+            metrics={},
+        )
+        empty_proxy = DriverClientProxy(
+            0,
+            driver.grpc_driver,  # type: ignore
+            False,
+            driver.run_id,  # type: ignore
+        )
+        aggregated_result = context.strategy.aggregate_fit(
+            current_round, [(empty_proxy, final_fitres)], []
+        )
+        parameters_aggregated, metrics_aggregated = aggregated_result
+
+        # Update the parameters and write history
+        if parameters_aggregated:
+            paramsrecord = compat.parameters_to_parametersrecord(
+                parameters_aggregated, True
+            )
+            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+        context.history.add_metrics_distributed_fit(
+            server_round=current_round, metrics=metrics_aggregated
+        )
+        return True
diff --git a/src/py/flwr/simulation/__init__.py b/src/py/flwr/simulation/__init__.py
index 724ea9273916..d36d9977d1c5 100644
--- a/src/py/flwr/simulation/__init__.py
+++ b/src/py/flwr/simulation/__init__.py
@@ -17,6 +17,8 @@
 
 import importlib
 
+from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
+
 is_ray_installed = importlib.util.find_spec("ray") is not None
 
 if is_ray_installed:
@@ -34,6 +36,4 @@ def start_simulation(*args, **kwargs):  # type: ignore
         raise ImportError(RAY_IMPORT_ERROR)
 
 
-__all__ = [
-    "start_simulation",
-]
+__all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py
index 6a18a258ac60..ff18f37664be 100644
--- a/src/py/flwr/simulation/app.py
+++ b/src/py/flwr/simulation/app.py
@@ -28,13 +28,13 @@
 from flwr.client import ClientFn
 from flwr.common import EventType, event
 from flwr.common.logger import log
-from flwr.server import Server
-from flwr.server.app import ServerConfig, init_defaults, run_fl
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
+from flwr.server.server import Server, init_defaults, run_fl
+from flwr.server.server_config import ServerConfig
 from flwr.server.strategy import Strategy
 from flwr.simulation.ray_transport.ray_actor import (
-    DefaultActor,
+    ClientAppActor,
     VirtualClientEngineActor,
     VirtualClientEngineActorPool,
     pool_size_from_resources,
@@ -82,7 +82,7 @@ def start_simulation(
     client_manager: Optional[ClientManager] = None,
     ray_init_args: Optional[Dict[str, Any]] = None,
     keep_initialised: Optional[bool] = False,
-    actor_type: Type[VirtualClientEngineActor] = DefaultActor,
+    actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
     actor_kwargs: Optional[Dict[str, Any]] = None,
     actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT",
 ) -> History:
@@ -138,10 +138,10 @@ def start_simulation(
     keep_initialised: Optional[bool] (default: False)
         Set to True to prevent `ray.shutdown()` in case `ray.is_initialized()=True`.
 
-    actor_type: VirtualClientEngineActor (default: DefaultActor)
+    actor_type: VirtualClientEngineActor (default: ClientAppActor)
         Optionally specify the type of actor to use. The actor object, which
         persists throughout the simulation, will be the process in charge of
-        running the clients' jobs (i.e. their `fit()` method).
+        executing a ClientApp wrapping input argument `client_fn`.
 
     actor_kwargs: Optional[Dict[str, Any]] (default: None)
         If you want to create your own Actor classes, you might need to pass
@@ -219,7 +219,7 @@ def start_simulation(
     log(
         INFO,
         "Optimize your simulation with Flower VCE: "
-        "https://flower.dev/docs/framework/how-to-run-simulations.html",
+        "https://flower.ai/docs/framework/how-to-run-simulations.html",
    )
 
     # Log the resources that a single client will be able to use
@@ -337,7 +337,7 @@ def update_resources(f_stop: threading.Event) -> None:
             "disconnected. The head node might still be alive but cannot accommodate "
             "any actor with resources: %s."
             "\nTake a look at the Flower simulation examples for guidance "
-            ".",
+            ".",
            client_resources,
            client_resources,
        )
diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py
index 38af3f08daa2..08d0576e39f0 100644
--- a/src/py/flwr/simulation/ray_transport/ray_actor.py
+++ b/src/py/flwr/simulation/ray_transport/ray_actor.py
@@ -14,29 +14,22 @@
 # ==============================================================================
 """Ray-based Flower Actor and ActorPool implementation."""
-
+import asyncio
 import threading
 import traceback
 from abc import ABC
-from logging import ERROR, WARNING
+from logging import DEBUG, ERROR, WARNING
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import ray
 from ray import ObjectRef
 from ray.util.actor_pool import ActorPool
 
-from flwr import common
-from flwr.client import Client, ClientFn
-from flwr.client.run_state import RunState
+from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.common import Context, Message
 from flwr.common.logger import log
-from flwr.simulation.ray_transport.utils import check_clientfn_returns_client
 
-# All possible returns by a client
-ClientRes = Union[
-    common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes
-]
-# A function to be executed by a client to obtain some results
-JobFn = Callable[[Client], ClientRes]
+ClientAppFn = Callable[[], ClientApp]
 
 
 class ClientException(Exception):
@@ -53,32 +46,33 @@ class VirtualClientEngineActor(ABC):
 
     def terminate(self) -> None:
         """Manually terminate Actor object."""
-        log(WARNING, "Manually terminating %s}", self.__class__.__name__)
+        log(WARNING, "Manually terminating %s", self.__class__.__name__)
         ray.actor.exit_actor()
 
     def run(
         self,
-        client_fn: ClientFn,
-        job_fn: JobFn,
+        client_app_fn: ClientAppFn,
+        message: Message,
         cid: str,
-        state: RunState,
-    ) -> Tuple[str, ClientRes, RunState]:
+        context: Context,
+    ) -> Tuple[str, Message, Context]:
         """Run a client run."""
-        # Execute tasks and return result
+        # Pass message through ClientApp and return a message
         # return also cid which is needed to ensure results
         # from the pool are correctly assigned to each ClientProxy
         try:
-            # Instantiate client (check 'Client' type is returned)
-            client = check_clientfn_returns_client(client_fn(cid))
-            # Inject state
-            client.set_state(state)
-            # Run client job
-            job_results = job_fn(client)
-            # Retrieve state (potentially updated)
-            updated_state = client.get_state()
+            # Load app
+            app: ClientApp = client_app_fn()
+
+            # Handle task message
+            out_message = app(message=message, context=context)
+
+        except LoadClientAppError as load_ex:
+            raise load_ex
+
        except Exception as ex:
            client_trace = traceback.format_exc()
-            message = (
+            mssg = (
                "\n\tSomething went wrong when running your client run."
                "\n\tClient "
                + cid
@@ -87,13 +81,13 @@ def run(
                + " was running its run."
                "\n\tException triggered on the client side: " + client_trace,
            )
-            raise ClientException(str(message)) from ex
+            raise ClientException(str(mssg)) from ex
 
-        return cid, job_results, updated_state
+        return cid, out_message, context
 
 
 @ray.remote
-class DefaultActor(VirtualClientEngineActor):
+class ClientAppActor(VirtualClientEngineActor):
     """A Ray Actor class that runs client runs.
     Parameters
@@ -237,16 +231,16 @@ def add_actors_to_pool(self, num_actors: int) -> None:
             self._idle_actors.extend(new_actors)
             self.num_actors += num_actors
 
-    def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None:
-        """Take idle actor and assign it a client run.
+    def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None:
+        """Take an idle actor and assign it to run a client app and Message.
 
         Submit a job to an actor by first removing it from the list of idle actors, then
-        check if this actor was flagged to be removed from the pool
+        check if this actor was flagged to be removed from the pool.
         """
-        client_fn, job_fn, cid, state = value
+        app_fn, mssg, cid, context = value
         actor = self._idle_actors.pop()
         if self._check_and_remove_actor_from_pool(actor):
-            future = fn(actor, client_fn, job_fn, cid, state)
+            future = fn(actor, app_fn, mssg, cid, context)
             future_key = tuple(future) if isinstance(future, List) else future
             self._future_to_actor[future_key] = (self._next_task_index, actor, cid)
             self._next_task_index += 1
@@ -255,7 +249,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None:
             self._cid_to_future[cid]["future"] = future_key
 
     def submit_client_job(
-        self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RunState]
+        self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
     ) -> None:
         """Submit a job while tracking client ids."""
         _, _, cid, _ = job
@@ -295,17 +289,17 @@ def _is_future_ready(self, cid: str) -> bool:
         return self._cid_to_future[cid]["ready"]  # type: ignore
 
-    def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]:
-        """Fetch result and updated state for a VirtualClient from Object Store.
+    def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]:
+        """Fetch result and updated context for a VirtualClient from Object Store.
 
         The job submitted by the ClientProxy interfacing with client with cid=cid is
         ready. Here we fetch it from the object store and return.
         """
         try:
             future: ObjectRef[Any] = self._cid_to_future[cid]["future"]  # type: ignore
-            res_cid, res, updated_state = ray.get(
+            res_cid, out_mssg, updated_context = ray.get(
                 future
-            )  # type: (str, ClientRes, RunState)
+            )  # type: (str, Message, Context)
         except ray.exceptions.RayActorError as ex:
             log(ERROR, ex)
             if hasattr(ex, "actor_id"):
@@ -322,7 +316,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]:
         # Reset mapping
         self._reset_cid_to_future_dict(cid)
 
-        return res, updated_state
+        return out_mssg, updated_context
 
     def _flag_actor_for_removal(self, actor_id_hex: str) -> None:
         """Flag actor that should be removed from pool."""
@@ -409,7 +403,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None:
 
     def get_client_result(
         self, cid: str, timeout: Optional[float]
-    ) -> Tuple[ClientRes, RunState]:
+    ) -> Tuple[Message, Context]:
         """Get result from VirtualClient with specific cid."""
         # Loop until all jobs submitted to the pool are completed. Break early
         # if the result for the ClientProxy calling this method is ready
@@ -421,5 +415,92 @@ def get_client_result(
                 break
 
         # Fetch result belonging to the VirtualClient calling this method
-        # Return both result from tasks and (potentially) updated run state
+        # Return both result from tasks and (potentially) updated run context
         return self._fetch_future_result(cid)
+
+
+def init_ray(*args: Any, **kwargs: Any) -> None:
+    """Initialises Ray if not already initialised."""
+    if not ray.is_initialized():
+        ray.init(*args, **kwargs)
+
+
+class BasicActorPool:
+    """A basic actor pool."""
+
+    def __init__(
+        self,
+        actor_type: Type[VirtualClientEngineActor],
+        client_resources: Dict[str, Union[int, float]],
+        actor_kwargs: Dict[str, Any],
+    ):
+        self.client_resources = client_resources
+
+        # Queue of idle actors
+        self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
+            maxsize=1024
+        )
+        self.num_actors = 0
+
+        # Resolve arguments to pass during actor init
+        actor_args = {} if actor_kwargs is None else actor_kwargs
+
+        # A function that creates an actor
+        self.create_actor_fn = lambda: actor_type.options(  # type: ignore
+            **client_resources
+        ).remote(**actor_args)
+
+        # Figure out how many actors can be created given the cluster resources
+        # and the resources the user indicates each VirtualClient will need
+        self.actors_capacity = pool_size_from_resources(client_resources)
+        self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
+
+    def is_actor_available(self) -> bool:
+        """Return true if there is an idle actor."""
+        return self.pool.qsize() > 0
+
+    async def add_actors_to_pool(self, num_actors: int) -> None:
+        """Add actors to the pool.
+
+        This method may also be executed if new resources are added to your Ray cluster
+        (e.g. you add a new node).
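+        Each call simply enqueues `num_actors` additional actors and increases
+        the pool's size counter accordingly.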
+ """ + for _ in range(num_actors): + await self.pool.put(self.create_actor_fn()) # type: ignore + self.num_actors += num_actors + + async def terminate_all_actors(self) -> None: + """Terminate actors in pool.""" + num_terminated = 0 + while self.pool.qsize(): + actor = await self.pool.get() + actor.terminate.remote() # type: ignore + num_terminated += 1 + + log(DEBUG, "Terminated %i actors", num_terminated) + + async def submit( + self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + ) -> Any: + """On idle actor, submit job and return future.""" + # Remove idle actor from pool + actor = await self.pool.get() + # Submit job to actor + app_fn, mssg, cid, context = job + future = actor_fn(actor, app_fn, mssg, cid, context) + # Keep track of future:actor (so we can fetch the actor upon job completion + # and add it back to the pool) + self._future_to_actor[future] = actor + return future + + async def fetch_result_and_return_actor_to_pool( + self, future: Any + ) -> Tuple[Message, Context]: + """Pull result given a future and add actor back to pool.""" + # Get actor that ran job + actor = self._future_to_actor.pop(future) + await self.pool.put(actor) + # Retrieve result for object store + # Instead of doing ray.get(future) we await it + _, out_mssg, updated_context = await future + return out_mssg, updated_context diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 5c05850dfd2f..c3493163ac52 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,107 +17,27 @@ import traceback from logging import ERROR -from typing import Dict, Optional, cast - -import ray +from typing import Optional from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.client import ( - maybe_call_evaluate, - maybe_call_fit, - maybe_call_get_parameters, - maybe_call_get_properties, -) +from flwr.client import ClientFn +from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState +from flwr.common import Message, Metadata, RecordSet +from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.logger import log -from flwr.server.client_proxy import ClientProxy -from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - JobFn, - VirtualClientEngineActorPool, +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + getpropertiesins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, + recordset_to_getpropertiesres, ) - - -class RayClientProxy(ClientProxy): - """Flower client proxy which delegates work using Ray.""" - - def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]): - super().__init__(cid) - self.client_fn = client_fn - self.resources = resources - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - future_get_properties_res = launch_and_get_properties.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_get_properties_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetPropertiesRes, - res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> 
common.GetParametersRes: - """Return the current local model parameters.""" - future_paramseters_res = launch_and_get_parameters.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_paramseters_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: - """Train model parameters on the locally held dataset.""" - future_fit_res = launch_and_fit.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_fit_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.FitRes, - res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - future_evaluate_res = launch_and_evaluate.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_evaluate_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.EvaluateRes, - res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) +from flwr.server.client_proxy import ClientProxy +from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool class RayActorClientProxy(ClientProxy): @@ -127,31 +47,35 @@ def __init__( self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool ): super().__init__(cid) - self.client_fn = client_fn + + def _load_app() -> ClientApp: + return ClientApp(client_fn=client_fn) + + self.app_fn = _load_app self.actor_pool = actor_pool self.proxy_state = NodeState() - def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: - # The VCE is not exposed to TaskIns, it won't handle multilple runs - # For the time being, fixing run_id is a small compromise - # This will be one of the first points to address integrating VCE + DriverAPI - run_id = 0 + def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: + """Sumbit a message to the ActorPool.""" + run_id = message.metadata.run_id # Register state - self.proxy_state.register_runstate(run_id=run_id) + self.proxy_state.register_context(run_id=run_id) # Retrieve state - state = self.proxy_state.retrieve_runstate(run_id=run_id) + state = self.proxy_state.retrieve_context(run_id=run_id) try: self.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (self.app_fn, message, self.cid, state), + ) + out_mssg, updated_context = self.actor_pool.get_client_result( + self.cid, timeout ) - res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) # Update state - self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) + self.proxy_state.update_context(run_id=run_id, context=updated_context) except Exception as ex: if self.actor_pool.num_actors == 0: @@ -162,134 +86,110 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: log(ERROR, ex) raise ex - return res + return out_mssg + + def 
_wrap_recordset_in_message( + self, + recordset: RecordSet, + message_type: str, + timeout: Optional[float], + group_id: Optional[int], + ) -> Message: + """Wrap a RecordSet inside a Message.""" + return Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id=str(group_id) if group_id is not None else "", + src_node_id=0, + dst_node_id=int(self.cid), + reply_to_message="", + ttl=str(timeout) if timeout else "", + message_type=message_type, + partition_id=int(self.cid), + ), + ) def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] + self, + ins: common.GetPropertiesIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Return client's properties.""" + recordset = getpropertiesins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageTypeLegacy.GET_PROPERTIES, + timeout=timeout, + group_id=group_id, + ) - def get_properties(client: Client) -> common.GetPropertiesRes: - return maybe_call_get_properties( - client=client, - get_properties_ins=ins, - ) - - res = self._submit_job(get_properties, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.GetPropertiesRes, - res, - ) + return recordset_to_getpropertiesres(message_out.content) def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] + self, + ins: common.GetParametersIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" + recordset = getparametersins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageTypeLegacy.GET_PARAMETERS, + timeout=timeout, + group_id=group_id, + ) - def get_parameters(client: Client) -> common.GetParametersRes: - return maybe_call_get_parameters( - client=client, - get_parameters_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(get_parameters, timeout) + return recordset_to_getparametersres(message_out.content, keep_input=False) - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: + def fit( + self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.FitRes: """Train model parameters on the locally held dataset.""" + recordset = fitins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageType.TRAIN, + timeout=timeout, + group_id=group_id, + ) - def fit(client: Client) -> common.FitRes: - return maybe_call_fit( - client=client, - fit_ins=ins, - ) - - res = self._submit_job(fit, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.FitRes, - res, - ) + return recordset_to_fitres(message_out.content, keep_input=False) def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] + self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" + recordset = evaluateins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageType.EVALUATE, + timeout=timeout, + group_id=group_id, + ) - def evaluate(client: Client) -> common.EvaluateRes: - return maybe_call_evaluate( - 
client=client, - evaluate_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(evaluate, timeout) - - return cast( - common.EvaluateRes, - res, - ) + return recordset_to_evaluateres(message_out.content) def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] + self, + ins: common.ReconnectIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - - -@ray.remote -def launch_and_get_properties( - client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns -) -> common.GetPropertiesRes: - """Exectue get_properties remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - -@ray.remote -def launch_and_get_parameters( - client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns -) -> common.GetParametersRes: - """Exectue get_parameters remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - -@ray.remote -def launch_and_fit( - client_fn: ClientFn, cid: str, fit_ins: common.FitIns -) -> common.FitRes: - """Exectue fit remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - -@ray.remote -def launch_and_evaluate( - client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns -) -> common.EvaluateRes: - """Exectue evaluate remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - -def _create_client(client_fn: ClientFn, cid: str) -> Client: - """Create a client instance.""" - # Materialize client - return client_fn(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 9df71635b949..22c5425cd9fd 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,17 +17,29 @@ from math import pi from random import shuffle -from typing import List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type import ray from flwr.client import Client, NumPyClient -from flwr.client.run_state import RunState -from flwr.common import Code, GetPropertiesRes, Status +from flwr.client.client_app import ClientApp +from flwr.common import ( + Config, + ConfigsRecord, + Context, + Message, + MessageTypeLegacy, + Metadata, + RecordSet, + Scalar, +) +from flwr.common.recordset_compat import ( + getpropertiesins_to_recordset, + recordset_to_getpropertiesres, +) +from flwr.common.recordset_compat_test import _get_valid_getpropertiesins from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - DefaultActor, - JobFn, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, ) @@ -40,32 +52,24 @@ class DummyClient(NumPyClient): def __init__(self, cid: str) -> None: self.cid = int(cid) + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = int(self.cid) * pi + + # store something in context + self.context.state.configs_records["result"] = ConfigsRecord( + {"result": str(result)} + ) + return {"result": result} + def get_dummy_client(cid: 
str) -> Client: """Return a DummyClient converted to Client type.""" return DummyClient(cid).to_client() -# A dummy run -def job_fn(cid: str) -> JobFn: # pragma: no cover - """Construct a simple job with cid dependency.""" - - def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument - result = int(cid) * pi - - # store something in state - client.numpy_client.state.state["result"] = str(result) # type: ignore - - # now let's convert it to a GetPropertiesRes response - return GetPropertiesRes( - status=Status(Code(0), message="test"), properties={"result": result} - ) - - return cid_times_pi - - def prep( - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, ) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} @@ -100,13 +104,24 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ proxies, _ = prep() + + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit jobs one at a time for prox in proxies: - res = prox._submit_job( # pylint: disable=protected-access - job_fn=job_fn(prox.cid), timeout=None + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + MessageTypeLegacy.GET_PROPERTIES, + timeout=None, + group_id=0, + ) + message_out = prox._submit_job( # pylint: disable=protected-access + message=message, timeout=None ) - res = cast(GetPropertiesRes, res) + res = recordset_to_getpropertiesres(message_out.content) + assert int(prox.cid) * pi == res.properties["result"] ray.shutdown() @@ -121,30 +136,43 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: proxies, _ = prep() run_id = 0 + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit all jobs (collect later) shuffle(proxies) for prox in proxies: # Register state - prox.proxy_state.register_runstate(run_id=run_id) + prox.proxy_state.register_context(run_id=run_id) # Retrieve state - state = prox.proxy_state.retrieve_runstate(run_id=run_id) + state = prox.proxy_state.retrieve_context(run_id=run_id) - job = job_fn(prox.cid) + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + message_type=MessageTypeLegacy.GET_PROPERTIES, + timeout=None, + group_id=0, + ) prox.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (prox.app_fn, message, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) - prox.proxy_state.update_runstate(run_id, run_state=updated_state) - res = cast(GetPropertiesRes, res) + message_out, updated_context = prox.actor_pool.get_client_result( + prox.cid, timeout=None + ) + prox.proxy_state.update_context(run_id, context=updated_context) + res = recordset_to_getpropertiesres(message_out.content) + assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_runstate(run_id).state["result"] + == prox.proxy_state.retrieve_context(run_id).state.configs_records[ 
+ "result" + ]["result"] ) ray.shutdown() @@ -156,20 +184,39 @@ def test_cid_consistency_without_proxies() -> None: num_clients = len(proxies) cids = [str(cid) for cid in range(num_clients)] + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + + def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) + # submit all jobs (collect later) shuffle(cids) for cid in cids: - job = job_fn(cid) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id=str(0), + src_node_id=0, + dst_node_id=12345, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + partition_id=int(cid), + ), + ) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, RunState(state={})), + (_load_app, message, cid, Context(state=RecordSet())), ) # fetch results one at a time shuffle(cids) for cid in cids: - res, _ = pool.get_client_result(cid, timeout=None) - res = cast(GetPropertiesRes, res) + message_out, _ = pool.get_client_result(cid, timeout=None) + res = recordset_to_getpropertiesres(message_out.content) assert int(cid) * pi == res.properties["result"] ray.shutdown() diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index 41aa8049eaf0..3861164998a4 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -15,9 +15,9 @@ """Utilities for Actors in the Virtual Client Engine.""" import traceback +import warnings from logging import ERROR -from flwr.client import Client from flwr.common.logger import log try: @@ -26,7 +26,7 @@ TF = None # Display Deprecation warning once -# warnings.filterwarnings("once", category=DeprecationWarning) +warnings.filterwarnings("once", category=DeprecationWarning) def enable_tf_gpu_growth() -> None: @@ -59,25 +59,3 @@ def enable_tf_gpu_growth() -> None: log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex - - -def check_clientfn_returns_client(client: Client) -> Client: - """Warn once that clients returned in `clinet_fn` should be of type Client. - - This is here for backwards compatibility. If a ClientFn is provided returning - a different type of client (e.g. NumPyClient) we'll warn the user but convert - the client internally to `Client` by calling `.to_client()`. - """ - if not isinstance(client, Client): - # mssg = ( - # " Ensure your client is of type `Client`. Please convert it" - # " using the `.to_client()` method before returning it" - # " in the `client_fn` you pass to `start_simulation`." - # " We have applied this conversion on your behalf." - # " Not returning a `Client` might trigger an error in future" - # " versions of Flower." - # ) - - # warnings.warn(mssg, DeprecationWarning, stacklevel=2) - client = client.to_client() - return client diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py new file mode 100644 index 000000000000..31884f2edc68 --- /dev/null +++ b/src/py/flwr/simulation/run_simulation.py @@ -0,0 +1,441 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower Simulation."""
+
+import argparse
+import asyncio
+import json
+import logging
+import threading
+import traceback
+from logging import DEBUG, ERROR, INFO, WARNING
+from time import sleep
+from typing import Dict, Optional
+
+import grpc
+
+from flwr.client import ClientApp
+from flwr.common import EventType, event, log
+from flwr.common.typing import ConfigsRecordValues
+from flwr.server.driver.driver import Driver
+from flwr.server.run_serverapp import run
+from flwr.server.server_app import ServerApp
+from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
+from flwr.server.superlink.fleet import vce
+from flwr.server.superlink.state import StateFactory
+from flwr.simulation.ray_transport.utils import (
+    enable_tf_gpu_growth as enable_gpu_growth,
+)
+
+
+# Entry point from CLI
+def run_simulation_from_cli() -> None:
+    """Run Simulation Engine from the CLI."""
+    args = _parse_args_run_simulation().parse_args()
+
+    # Load JSON config
+    backend_config_dict = json.loads(args.backend_config)
+
+    _run_simulation(
+        server_app_attr=args.server_app,
+        client_app_attr=args.client_app,
+        num_supernodes=args.num_supernodes,
+        backend_name=args.backend,
+        backend_config=backend_config_dict,
+        app_dir=args.app_dir,
+        driver_api_address=args.driver_api_address,
+        enable_tf_gpu_growth=args.enable_tf_gpu_growth,
+        verbose_logging=args.verbose,
+    )
+
+
+# Entry point from Python session (script or notebook)
+# pylint: disable=too-many-arguments
+def run_simulation(
+    server_app: ServerApp,
+    client_app: ClientApp,
+    num_supernodes: int,
+    backend_name: str = "ray",
+    backend_config: Optional[Dict[str, ConfigsRecordValues]] = None,
+    enable_tf_gpu_growth: bool = False,
+    verbose_logging: bool = False,
+) -> None:
+    r"""Run a Flower App using the Simulation Engine.
+
+    Parameters
+    ----------
+    server_app : ServerApp
+        The `ServerApp` to be executed. It will send messages to different `ClientApp`
+        instances running on different (virtual) SuperNodes.
+
+    client_app : ClientApp
+        The `ClientApp` to be executed by each of the SuperNodes. It will receive
+        messages sent by the `ServerApp`.
+
+    num_supernodes : int
+        Number of nodes that run a ClientApp. They can be sampled by a
+        Driver in the ServerApp and receive Messages describing what the ClientApp
+        should perform.
+
+    backend_name : str (default: ray)
+        A simulation backend that runs `ClientApp`s.
+
+    backend_config : Optional[Dict[str, ConfigsRecordValues]]
+        A dictionary, e.g. {"<keyA>": <value>, "<keyB>": <value>}, to configure a
+        backend. Values supported in <value> are those included by
+        `flwr.common.typing.ConfigsRecordValues`.
+
+    enable_tf_gpu_growth : bool (default: False)
+        A boolean to indicate whether to enable GPU growth on the main thread. This is
+        desirable if you make use of a TensorFlow model on your `ServerApp` while
+        having your `ClientApp` running on the same GPU. Without enabling this, you
+        might encounter an out-of-memory error because TensorFlow, by default,
+        allocates all GPU memory. Read more about how
+        `tf.config.experimental.set_memory_growth()` works in the TensorFlow
+        documentation: https://www.tensorflow.org/api/stable.
+
+    verbose_logging : bool (default: False)
+        When disabled, only INFO, WARNING and ERROR log messages will be shown. If
+        enabled, DEBUG-level logs will be displayed.
+    """
+    _run_simulation(
+        num_supernodes=num_supernodes,
+        client_app=client_app,
+        server_app=server_app,
+        backend_name=backend_name,
+        backend_config=backend_config,
+        enable_tf_gpu_growth=enable_tf_gpu_growth,
+        verbose_logging=verbose_logging,
+    )
+
+
+# pylint: disable=too-many-arguments
+def run_serverapp_th(
+    server_app_attr: Optional[str],
+    server_app: Optional[ServerApp],
+    driver: Driver,
+    app_dir: str,
+    f_stop: asyncio.Event,
+    enable_tf_gpu_growth: bool,
+    delay_launch: int = 3,
+) -> threading.Thread:
+    """Run ServerApp in a thread."""
+
+    def server_th_with_start_checks(  # type: ignore
+        tf_gpu_growth: bool, stop_event: asyncio.Event, **kwargs
+    ) -> None:
+        """Run ServerApp, after checking if GPU memory growth has to be set.
+
+        Upon exception, trigger stop event for Simulation Engine.
+        """
+        try:
+            if tf_gpu_growth:
+                log(INFO, "Enabling GPU growth for TensorFlow on the main thread.")
+                enable_gpu_growth()
+
+            # Run ServerApp
+            run(**kwargs)
+        except Exception as ex:  # pylint: disable=broad-exception-caught
+            log(ERROR, "ServerApp thread raised an exception: %s", ex)
+            log(ERROR, traceback.format_exc())
+        finally:
+            log(DEBUG, "ServerApp finished running.")
+            # Upon completion, trigger stop event if one was passed
+            if stop_event is not None:
+                stop_event.set()
+                log(WARNING, "Triggered stop event for Simulation Engine.")
+
+    serverapp_th = threading.Thread(
+        target=server_th_with_start_checks,
+        args=(enable_tf_gpu_growth, f_stop),
+        kwargs={
+            "server_app_attr": server_app_attr,
+            "loaded_server_app": server_app,
+            "driver": driver,
+            "server_app_dir": app_dir,
+        },
+    )
+    sleep(delay_launch)
+    serverapp_th.start()
+    return serverapp_th
+
+
+# pylint: disable=too-many-locals
+def _main_loop(
+    num_supernodes: int,
+    backend_name: str,
+    backend_config_stream: str,
+    driver_api_address: str,
+    app_dir: str,
+    enable_tf_gpu_growth: bool,
+    client_app: Optional[ClientApp] = None,
+    client_app_attr: Optional[str] = None,
+    server_app: Optional[ServerApp] = None,
+    server_app_attr: Optional[str] = None,
+) -> None:
+    """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.
+
+    Everything runs on the main thread or a separate one, depending on whether the
+    main thread already contains a running Asyncio event loop. This is the case when
+    running the Simulation Engine in a Jupyter/Colab notebook.
+ """ + # Initialize StateFactory + state_factory = StateFactory(":flwr-in-memory-state:") + + # Start Driver API + driver_server: grpc.Server = run_driver_api_grpc( + address=driver_api_address, + state_factory=state_factory, + certificates=None, + ) + + f_stop = asyncio.Event() + serverapp_th = None + try: + # Initialize Driver + driver = Driver( + driver_service_address=driver_api_address, + root_certificates=None, + ) + + # Get and run ServerApp thread + serverapp_th = run_serverapp_th( + server_app_attr=server_app_attr, + server_app=server_app, + driver=driver, + app_dir=app_dir, + f_stop=f_stop, + enable_tf_gpu_growth=enable_tf_gpu_growth, + ) + + # SuperLink with Simulation Engine + event(EventType.RUN_SUPERLINK_ENTER) + vce.start_vce( + num_supernodes=num_supernodes, + client_app_attr=client_app_attr, + client_app=client_app, + backend_name=backend_name, + backend_config_json_stream=backend_config_stream, + app_dir=app_dir, + state_factory=state_factory, + f_stop=f_stop, + ) + + except Exception as ex: + log(ERROR, "An exception occurred !! %s", ex) + log(ERROR, traceback.format_exc()) + raise RuntimeError("An error was encountered. Ending simulation.") from ex + + finally: + # Stop Driver + driver_server.stop(grace=0) + del driver + # Trigger stop event + f_stop.set() + + event(EventType.RUN_SUPERLINK_LEAVE) + if serverapp_th: + serverapp_th.join() + + log(INFO, "Stopping Simulation Engine now.") + + +# pylint: disable=too-many-arguments,too-many-locals +def _run_simulation( + num_supernodes: int, + client_app: Optional[ClientApp] = None, + server_app: Optional[ServerApp] = None, + backend_name: str = "ray", + backend_config: Optional[Dict[str, ConfigsRecordValues]] = None, + client_app_attr: Optional[str] = None, + server_app_attr: Optional[str] = None, + app_dir: str = "", + driver_api_address: str = "0.0.0.0:9091", + enable_tf_gpu_growth: bool = False, + verbose_logging: bool = False, +) -> None: + r"""Launch the Simulation Engine. + + Parameters + ---------- + num_supernodes : int + Number of nodes that run a ClientApp. They can be sampled by a + Driver in the ServerApp and receive a Message describing what the ClientApp + should perform. + + client_app : Optional[ClientApp] + The `ClientApp` to be executed by each of the `SuperNodes`. It will receive + messages sent by the `ServerApp`. + + server_app : Optional[ServerApp] + The `ServerApp` to be executed. + + backend_name : str (default: ray) + A simulation backend that runs `ClientApp`s. + + backend_config : Optional[Dict[str, ConfigsRecordValues]] + 'A dictionary, e.g {"":, "":} to configure a + backend. Values supported in are those included by + `flwr.common.typing.ConfigsRecordValues`. + + client_app_attr : str + A path to a `ClientApp` module to be loaded: For example: `client:app` or + `project.package.module:wrapper.app`." + + server_app_attr : str + A path to a `ServerApp` module to be loaded: For example: `server:app` or + `project.package.module:wrapper.app`." + + app_dir : str + Add specified directory to the PYTHONPATH and load `ClientApp` from there. + (Default: current working directory.) + + driver_api_address : str (default: "0.0.0.0:9091") + Driver API (gRPC) server address (IPv4, IPv6, or a domain name) + + enable_tf_gpu_growth : bool (default: False) + A boolean to indicate whether to enable GPU growth on the main thread. This is + desirable if you make use of a TensorFlow model on your `ServerApp` while + having your `ClientApp` running on the same GPU. 
+        you might encounter an out-of-memory error because TensorFlow by default
+        allocates all GPU memory. Read more about how
+        `tf.config.experimental.set_memory_growth()` works in the TensorFlow
+        documentation: https://www.tensorflow.org/api/stable.
+
+    verbose_logging : bool (default: False)
+        When disabled, only INFO, WARNING and ERROR log messages will be shown. If
+        enabled, DEBUG-level logs will be displayed.
+    """
+    # Set logging level
+    if not verbose_logging:
+        logger = logging.getLogger("flwr")
+        logger.setLevel(INFO)
+
+    if backend_config is None:
+        backend_config = {}
+
+    if enable_tf_gpu_growth:
+        # Check that backend config has also enabled using GPU growth
+        use_tf = backend_config.get("tensorflow", False)
+        if not use_tf:
+            log(WARNING, "Enabling GPU growth for your backend.")
+            backend_config["tensorflow"] = True
+
+    # Convert config to original JSON-stream format
+    backend_config_stream = json.dumps(backend_config)
+
+    simulation_engine_th = None
+    args = (
+        num_supernodes,
+        backend_name,
+        backend_config_stream,
+        driver_api_address,
+        app_dir,
+        enable_tf_gpu_growth,
+        client_app,
+        client_app_attr,
+        server_app,
+        server_app_attr,
+    )
+    # Detect if there is an Asyncio event loop already running.
+    # If yes, run everything on a separate thread. In environments
+    # like Jupyter/Colab notebooks, there is an event loop present.
+    run_in_thread = False
+    try:
+        _ = (
+            asyncio.get_running_loop()
+        )  # Raises RuntimeError if no event loop is present
+        log(DEBUG, "Asyncio event loop already running.")
+
+        run_in_thread = True
+
+    except RuntimeError:
+        log(DEBUG, "No asyncio event loop running.")
+
+    finally:
+        if run_in_thread:
+            log(DEBUG, "Starting Simulation Engine on a new thread.")
+            simulation_engine_th = threading.Thread(target=_main_loop, args=args)
+            simulation_engine_th.start()
+            simulation_engine_th.join()
+        else:
+            log(DEBUG, "Starting Simulation Engine on the main thread.")
+            _main_loop(*args)
+
+
+def _parse_args_run_simulation() -> argparse.ArgumentParser:
+    """Parse flower-simulation command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Start a Flower simulation",
+    )
+    parser.add_argument(
+        "--server-app",
+        required=True,
+        help="For example: `server:app` or `project.package.module:wrapper.app`",
+    )
+    parser.add_argument(
+        "--client-app",
+        required=True,
+        help="For example: `client:app` or `project.package.module:wrapper.app`",
+    )
+    parser.add_argument(
+        "--num-supernodes",
+        type=int,
+        required=True,
+        help="Number of simulated SuperNodes.",
+    )
+    parser.add_argument(
+        "--driver-api-address",
+        default="0.0.0.0:9091",
+        type=str,
+        help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
+    )
+    parser.add_argument(
+        "--backend",
+        default="ray",
+        type=str,
+        help="Simulation backend that executes the ClientApp.",
+    )
+    parser.add_argument(
+        "--backend-config",
+        type=str,
+        default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
+        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
+        "configure a backend. Values supported in <value> are those included by "
+        "`flwr.common.typing.ConfigsRecordValues`. ",
+    )
+    parser.add_argument(
+        "--enable-tf-gpu-growth",
+        action="store_true",
+        help="Enables GPU growth on the main thread. This is desirable if you make "
+        "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
Without enabling this, you might encounter an " + "out-of-memory error because TensorFlow by default allocates all GPU memory." + "Read more about how `tf.config.experimental.set_memory_growth()` works in " + "the TensorFlow documentation: https://www.tensorflow.org/api/stable.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="When unset, only INFO, WARNING and ERROR log messages will be shown. " + "If set, DEBUG-level logs will be displayed. ", + ) + parser.add_argument( + "--app-dir", + default="", + help="Add specified directory to the PYTHONPATH and load" + "ClientApp and ServerApp from there." + " Default: current working directory.", + ) + + return parser diff --git a/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py b/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py index 1213410fbc34..f6b922b27eab 100644 --- a/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py +++ b/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py @@ -21,7 +21,7 @@ def test_shuffle() -> None: - """Test if shuffle is deterministic depending on the the provided seed.""" + """Test if shuffle is deterministic depending on the provided seed.""" # Prepare x_tt = np.arange(8) y_tt = np.arange(8) diff --git a/src/py/flwr_experimental/baseline/config/config.py b/src/py/flwr_experimental/baseline/config/config.py index 5170ea7f7e26..16c144bb6a2f 100644 --- a/src/py/flwr_experimental/baseline/config/config.py +++ b/src/py/flwr_experimental/baseline/config/config.py @@ -23,7 +23,7 @@ from flwr_experimental.ops.instance import Instance # We assume that devices which are older will have at most -# ~80% of the the Samsung Galaxy Note 5 compute performance. +# ~80% of the Samsung Galaxy Note 5 compute performance. 
 SCORE_MISSING = int(226 * 0.80)
 
 DEVICE_DISTRIBUTION = [
diff --git a/src/py/flwr_experimental/baseline/tf_cifar/settings.py b/src/py/flwr_experimental/baseline/tf_cifar/settings.py
index 829856af0ace..ed1a72cafac9 100644
--- a/src/py/flwr_experimental/baseline/tf_cifar/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_cifar/settings.py
@@ -120,9 +120,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
index b0de9841de30..72adc1f0be04 100644
--- a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
@@ -141,9 +141,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
index 18c58d1333a5..5bfb7b1e42ad 100644
--- a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
@@ -144,9 +144,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/logserver/server.py b/src/py/flwr_experimental/logserver/server.py
index dc729cebbf33..683b12b6db6c 100644
--- a/src/py/flwr_experimental/logserver/server.py
+++ b/src/py/flwr_experimental/logserver/server.py
@@ -69,9 +69,9 @@ def upload_file(local_filepath: str, s3_key: Optional[str]) -> None:
         Bucket=CONFIG["s3_bucket"],
         Key=s3_key,
         ExtraArgs={
-            "ContentType": "application/pdf"
-            if s3_key.endswith(".pdf")
-            else "text/plain"
+            "ContentType": (
+                "application/pdf" if s3_key.endswith(".pdf") else "text/plain"
+            )
         },
     )
     # pylint: disable=broad-except
diff --git a/src/py/flwr_experimental/ops/cluster.py b/src/py/flwr_experimental/ops/cluster.py
index dfd0dd13727e..53a4e9617427 100644
--- a/src/py/flwr_experimental/ops/cluster.py
+++ b/src/py/flwr_experimental/ops/cluster.py
@@ -73,10 +73,10 @@ def ssh_connection(
     username, key_filename = ssh_credentials
     instance_ssh_port: int = cast(int, instance.ssh_port)
 
-    ignore_host_key_policy: Union[
-        Type[MissingHostKeyPolicy], MissingHostKeyPolicy
-    ] = cast(
-        Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy
+    ignore_host_key_policy: Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy] = (
+        cast(
+            Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy
+        )
     )
 
     client = SSHClient()
@@ -139,7 +139,7 @@ def group_instances_by_specs(instances: List[Instance]) -> List[List[Instance]]:
 
 
 class Cluster:
-    """Compute enviroment independend compute cluster."""
+    """Compute environment-independent compute
cluster.""" def __init__( self, diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 607d808c8497..2d48582eb441 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 6 + assert len(PROTO_FILES) == 7 diff --git a/src/py/flwr_tool/update_changelog.py b/src/py/flwr_tool/update_changelog.py index bbd5c7f3dc7b..a158cca21765 100644 --- a/src/py/flwr_tool/update_changelog.py +++ b/src/py/flwr_tool/update_changelog.py @@ -62,7 +62,7 @@ def _extract_changelog_entry(pr_info): f"{CHANGELOG_SECTION_HEADER}(.+?)(?=##|$)", pr_info.body, re.DOTALL ) if not entry_match: - return None, "general" + return None, None entry_text = entry_match.group(1).strip()