diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 61c8d14039db..03849995735d 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -31,7 +31,7 @@ RUN apt-get install -y curl wget gnupg python3 python-is-python3 python3-pip git RUN python -m pip install \ pip==23.3.1 \ setuptools==68.2.2 \ - poetry==1.5.1 + poetry==1.7.1 USER $USERNAME ENV PATH="/home/$USERNAME/.local/bin:${PATH}" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a5eadadf8604..34af632814a3 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,5 +3,17 @@ # Default code owners * @danieljanes @tanertopal +# README.md +README.md @jafermarq @tanertopal @danieljanes + # Flower Baselines /baselines @jafermarq @tanertopal @danieljanes + +# Flower Datasets +/datasets @jafermarq @tanertopal @danieljanes + +# Flower Examples +/examples @jafermarq @tanertopal @danieljanes + +# Changelog +/doc/source/ref-changelog.md @jafermarq @tanertopal @danieljanes diff --git a/.github/ISSUE_TEMPLATE/baseline_request.yml b/.github/ISSUE_TEMPLATE/baseline_request.yml index 49ae922a94ad..8df8eec22a21 100644 --- a/.github/ISSUE_TEMPLATE/baseline_request.yml +++ b/.github/ISSUE_TEMPLATE/baseline_request.yml @@ -38,18 +38,18 @@ body: attributes: label: For first time contributors value: | - - [ ] Read the [`first contribution` doc](https://flower.dev/docs/first-time-contributors.html) + - [ ] Read the [`first contribution` doc](https://flower.ai/docs/first-time-contributors.html) - [ ] Complete the Flower tutorial - [ ] Read the Flower Baselines docs to get an overview: - - [ ] [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) - - [ ] [How to contribute a Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) + - [ ] [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) + - [ ] [How to contribute a Flower 
Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) - type: checkboxes attributes: label: Prepare - understand the scope options: - label: Read the paper linked above - label: Decide which experiments you'd like to reproduce. The more the better! - - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). + - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). - label: You can use as reference [other baselines](https://github.com/adap/flower/tree/main/baselines) that the community merged following those steps. - type: checkboxes attributes: diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index d0ab46d60b57..9e0ed23d4dbb 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,14 +1,14 @@ blank_issues_enabled: false contact_links: - name: Slack Channel - url: https://flower.dev/join-slack + url: https://flower.ai/join-slack about: Connect with other Flower users and contributors and discuss with them or ask them questions. - name: Discussion url: https://github.com/adap/flower/discussions - about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. + about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. - name: Flower Issues url: https://github.com/adap/flower/issues about: Contribute new features/enhancements, report bugs, or improve the documentation. - name: Flower Mail - url: https://flower.dev/ - about: If your project needs professional support please contact the Flower team (hello@flower.dev). + url: https://flower.ai/ + about: If your project needs professional support please contact the Flower team (hello@flower.ai). 
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8d73ed618919..0077bbab0909 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -36,10 +36,29 @@ Example: The variable `rnd` was renamed to `server_round` to improve readability - [ ] Implement proposed change - [ ] Write tests -- [ ] Update [documentation](https://flower.dev/docs/writing-documentation.html) -- [ ] Update [changelog](https://github.com/adap/flower/blob/main/doc/source/changelog.rst) +- [ ] Update [documentation](https://flower.ai/docs/writing-documentation.html) +- [ ] Update the changelog entry below - [ ] Make CI checks pass -- [ ] Ping maintainers on [Slack](https://flower.dev/join-slack/) (channel `#contributions`) +- [ ] Ping maintainers on [Slack](https://flower.ai/join-slack/) (channel `#contributions`) + + + +### Changelog entry + + ### Any other comments? @@ -49,7 +68,7 @@ Smaller PRs with good descriptions can be considered much more easily. If you have an urgent request or question, please use the Flower Slack: - https://flower.dev/join-slack/ (channel: #contributions) + https://flower.ai/join-slack/ (channel: #contributions) Thank you for contributing to Flower! 
--> diff --git a/.github/actions/bootstrap/action.yml b/.github/actions/bootstrap/action.yml index bab5113bc567..7b1716bbc954 100644 --- a/.github/actions/bootstrap/action.yml +++ b/.github/actions/bootstrap/action.yml @@ -12,7 +12,7 @@ inputs: default: 68.2.2 poetry-version: description: "Version of poetry to be installed using pip" - default: 1.5.1 + default: 1.7.1 outputs: python-version: description: "Version range or exact version of Python or PyPy" @@ -30,7 +30,7 @@ runs: using: "composite" steps: - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ inputs.python-version }} - name: Install build tools diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index d36076444746..99ec0671db66 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -1,4 +1,4 @@ -name: Reusable docker server image build workflow +name: Reusable docker image build workflow on: workflow_call: @@ -35,19 +35,32 @@ permissions: # based on https://docs.docker.com/build/ci/github-actions/multi-platform/#distribute-build-across-multiple-runners jobs: build: - name: Build server image + name: Build image runs-on: ubuntu-22.04 timeout-minutes: 60 + outputs: + build-id: ${{ steps.build-id.outputs.id }} strategy: fail-fast: true matrix: platform: [ # build-push action and qemu use different platform names # therefore we create a map - { qemu: "", docker: "linux/amd64" }, - { qemu: "arm64", docker: "linux/arm64" }, + { name: "amd64", qemu: "", docker: "linux/amd64" }, + { name: "arm64", qemu: "arm64", docker: "linux/arm64" }, ] steps: + - name: Create build id + id: build-id + shell: python + run: | + import hashlib + import os + + hash = hashlib.sha256('''${{ inputs.build-args }}'''.encode()) + with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: + print(f"id={hash.hexdigest()}", file=fh) + - name: Set up QEMU if: matrix.platform.qemu != '' uses: 
docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # v3.0.0 @@ -85,9 +98,9 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # v3.1.3 + uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0 with: - name: digests + name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* if-no-files-found: error retention-days: 1 @@ -101,10 +114,11 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a # v3.0.2 + uses: actions/download-artifact@eaceaf801fd36c7dee90939fad912460b18a1ffe # v4.1.2 with: - name: digests + pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests + merge-multiple: true - name: Docker meta id: meta diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index bfb26053836d..c4485fe72d10 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -1,14 +1,5 @@ name: Baselines -# The aim of this workflow is to test only the changed (or added) baseline. -# Here is the rough idea of how it works (more details are presented later in the comments): -# 1. Checks for the changes between the current branch and the main - in case of PR - -# or between the HEAD and HEAD~1 (main last commit and the previous one) - in case of -# a push to main. -# 2. Fails the test if there are changes to more than one baseline. Passes the test -# (skips the rests) if there are no changes to any baselines. Follows the test if only -# one baseline is added or modified. -# 3. Sets up the env specified for the baseline. -# 4. Runs the tests. 
+ on: push: branches: @@ -24,112 +15,75 @@ concurrency: env: FLWR_TELEMETRY_ENABLED: 0 -defaults: - run: - working-directory: baselines - jobs: - test_baselines: - name: Test + changes: runs-on: ubuntu-22.04 + permissions: + pull-requests: read + outputs: + baselines: ${{ steps.filter.outputs.changes }} steps: - uses: actions/checkout@v4 - # The depth two of the checkout is needed in case of merging to the main - # because we compare the HEAD (current version) with HEAD~1 (version before - # the PR was merged) - with: - fetch-depth: 2 - - name: Fetch main branch - run: | - # The main branch is needed in case of the PR to make a comparison (by - # default the workflow takes as little information as possible - it does not - # have the history - if [ ${{ github.event_name }} == "pull_request" ] - then - git fetch origin main:main - fi - - name: Find changed/new baselines - id: find_changed_baselines_dirs + + - shell: bash run: | - if [ ${{ github.event_name }} == "push" ] - then - # Push event triggered when merging to main - change_references="HEAD..HEAD~1" - else - # Pull request event triggered for any commit to a pull request - change_references="main..HEAD" - fi - dirs=$(git diff --dirstat=files,0 ${change_references} . | awk '{print $2}' | grep -E '^baselines/[^/]*/$' | \ - grep -v \ - -e '^baselines/dev' \ - -e '^baselines/baseline_template' \ - -e '^baselines/flwr_baselines' \ - -e '^baselines/doc' \ - | sed 's/^baselines\///') - # git diff --dirstat=files,0 ${change_references} . - checks the differences - # and a file is counted as changed if more than 0 lines were changed - # it returns the results in the format x.y% path/to/dir/ - # awk '{print $2}' - takes only the directories (skips the percentages) - # grep -E '^baselines/[^/]*/$' - takes only the paths that start with - # baseline (and have at least one subdirectory) - # grep -v -e ... 
- excludes the `baseline_template`, `dev`, `flwr_baselines` - # sed 's/^baselines\///' - narrows down the path to baseline/ - echo "Detected changed directories: ${dirs}" - # Save changed dirs to output of this step - EOF=$(dd if=/dev/urandom bs=15 count=1 status=none | base64) - echo "dirs<> "$GITHUB_OUTPUT" - for dir in $dirs - do - echo "$dir" >> "$GITHUB_OUTPUT" - done - echo "EOF" >> "$GITHUB_OUTPUT" - - name: Validate changed/new baselines - id: validate_changed_baselines_dirs + # create a list of all directories in baselines + { + echo 'FILTER_PATHS<> "$GITHUB_ENV" + + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: ${{ env.FILTER_PATHS }} + + - if: ${{ github.event.pull_request.head.repo.fork }} run: | - dirs="${{ steps.find_changed_baselines_dirs.outputs.dirs }}" - dirs_array=() - if [[ -n $dirs ]]; then - while IFS= read -r line; do - dirs_array+=("$line") - done <<< "$dirs" - fi - length=${#dirs_array[@]} - echo "The number of changed baselines is $length" - - if [ $length -gt 1 ]; then - echo "The changes should only apply to a single baseline" - exit 1 + CHANGES=$(echo "${{ toJson(steps.filter.outputs.changes) }}" | jq '. | length') + if [ "$CHANGES" -gt 1 ]; then + echo "::error ::The changes should only apply to a single baseline." + exit 1 fi - - if [ $length -eq 0 ]; then - echo "The baselines were not changed - skipping the remaining steps." 
- echo "baseline_changed=false" >> "$GITHUB_OUTPUT" - exit 0 - fi - - echo "changed_dir=${dirs[0]}" >> "$GITHUB_OUTPUT" - echo "baseline_changed=true" >> "$GITHUB_OUTPUT" + + test: + runs-on: ubuntu-22.04 + needs: changes + if: ${{ needs.changes.outputs.baselines != '' && toJson(fromJson(needs.changes.outputs.baselines)) != '[]' }} + strategy: + matrix: + baseline: ${{ fromJSON(needs.changes.outputs.baselines) }} + steps: + - uses: actions/checkout@v4 + - name: Bootstrap - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: python-version: '3.10' + - name: Install dependencies - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - changed_dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - cd "${changed_dir}" - python -m poetry install - - name: Test - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline.sh $dir - - name: Test Structure - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline-structure.sh $dir + working-directory: baselines/${{ matrix.baseline }} + run: python -m poetry install + + - name: Testing ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline.sh ${{ matrix.baseline }} + - name: Test Structure of ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline-structure.sh ${{ matrix.baseline }} diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index a1d918aa5b29..7efc879bcee2 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -19,13 +19,13 @@ jobs: uses: ./.github/actions/bootstrap - name: Cache restore SDK build - uses: 
actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache restore example build - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: examples/quickstart-cpp/build/ key: ${{ runner.os }}-example-build @@ -82,14 +82,14 @@ jobs: fi - name: Cache save SDK build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache save example build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: examples/quickstart-cpp/build/ diff --git a/.github/workflows/datasets-e2e.yml b/.github/workflows/datasets-e2e.yml new file mode 100644 index 000000000000..2a73a8538b14 --- /dev/null +++ b/.github/workflows/datasets-e2e.yml @@ -0,0 +1,52 @@ +name: Datasets-E2E + +on: + push: + branches: + - main + paths: + - "datasets/flwr_datasets/**" + pull_request: + branches: + - main + paths: + - "datasets/flwr_datasets/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + FLWR_TELEMETRY_ENABLED: 0 + +jobs: + frameworks: + runs-on: ubuntu-22.04 + timeout-minutes: 10 + # Using approach described here: + # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs + strategy: + matrix: + include: + - directory: pytorch + + - directory: tensorflow + + - directory: scikit-learn + + name: Framework / ${{matrix.directory}} + + defaults: + run: + working-directory: datasets/e2e/${{ matrix.directory }} + + steps: + - uses: actions/checkout@v4 + - name: Bootstrap + uses: ./.github/actions/bootstrap + with: + python-version: 3.8 + - name: Install dependencies + run: python -m poetry install + - name: Run tests + run: python -m unittest discover -p '*_test.py' diff --git a/.github/workflows/docker-base.yml 
b/.github/workflows/docker-base.yml index 7b23340ec7d2..f2cd2ef99d08 100644 --- a/.github/workflows/docker-base.yml +++ b/.github/workflows/docker-base.yml @@ -39,13 +39,13 @@ jobs: echo "ubuntu-version=${{ env.DEFAULT_UBUNTU }}" >> "$GITHUB_OUTPUT" build-base-images: - name: Build images + name: Build base images uses: ./.github/workflows/_docker-build.yml needs: parameters strategy: - fail-fast: true + fail-fast: false matrix: - python-version: [3.8, 3.9, 3.10, 3.11] + python-version: ["3.8", "3.9", "3.10", "3.11"] with: namespace-repository: flwr/base file-dir: src/docker/base diff --git a/.github/workflows/docker-client.yml b/.github/workflows/docker-client.yml new file mode 100644 index 000000000000..3c2d83596733 --- /dev/null +++ b/.github/workflows/docker-client.yml @@ -0,0 +1,36 @@ +name: Build docker client image + +on: + workflow_dispatch: + inputs: + flwr-version: + description: "Version of Flower e.g. (1.7.0)." + required: true + type: string + +permissions: + contents: read + +jobs: + build-client-images: + name: Build client images + uses: ./.github/workflows/_docker-build.yml + # run only on default branch when using it with workflow_dispatch + if: github.ref_name == github.event.repository.default_branch + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + with: + namespace-repository: flwr/client + file-dir: src/docker/client + build-args: | + FLWR_VERSION=${{ github.event.inputs.flwr-version }} + BASE_IMAGE_TAG=py${{ matrix.python-version }}-ubuntu22.04 + tags: | + ${{ github.event.inputs.flwr-version }}-py${{ matrix.python-version }}-ubuntu22.04 + ${{ github.event.inputs.flwr-version }} + latest + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/docker-server.yml b/.github/workflows/docker-server.yml index 093f148c4c34..1e43715207d4 100644 --- a/.github/workflows/docker-server.yml +++ 
b/.github/workflows/docker-server.yml @@ -4,11 +4,11 @@ on: workflow_dispatch: inputs: flwr-version: - description: "Version of Flower e.g. (1.6.0)." + description: "Version of Flower e.g. (1.7.0)." required: true type: string - base-image-version: - description: "Version of the Flower base image." + base-image-tag: + description: "The tag of the Flower base image." required: false type: string default: "py3.11-ubuntu22.04" @@ -20,14 +20,16 @@ jobs: build-server-images: name: Build images uses: ./.github/workflows/_docker-build.yml + # run only on default branch when using it with workflow_dispatch + if: github.ref_name == github.event.repository.default_branch with: namespace-repository: flwr/server file-dir: src/docker/server build-args: | FLWR_VERSION=${{ github.event.inputs.flwr-version }} - BASE_IMAGE_VERSION=${{ github.event.inputs.base-image-version }} + BASE_IMAGE_TAG=${{ github.event.inputs.base-image-tag }} tags: | - ${{ github.event.inputs.flwr-version }}-${{ github.event.inputs.base-image-version }} + ${{ github.event.inputs.flwr-version }}-${{ github.event.inputs.base-image-tag }} ${{ github.event.inputs.flwr-version }} latest secrets: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 52944ffecf70..a4c769fdf850 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,6 +23,8 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Check copyright line + run: ./dev/test-copyright.sh - name: Bootstrap uses: ./.github/actions/bootstrap - name: Install pandoc @@ -41,8 +43,9 @@ jobs: AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets. 
AWS_SECRET_ACCESS_KEY }} + DOCS_BUCKET: flower.ai run: | - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://flower.dev/docs/framework - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://flower.dev/docs/baselines - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://flower.dev/docs/examples - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://flower.dev/docs/datasets + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/framework + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/baselines + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/examples + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/datasets diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index a5121ad71b38..62f3c0a78ce4 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -14,6 +14,7 @@ concurrency: env: FLWR_TELEMETRY_ENABLED: 0 + ARTIFACT_BUCKET: artifact.flower.ai jobs: wheel: @@ -30,12 +31,12 @@ jobs: - name: Test wheel run: ./dev/test-wheel.sh - name: Upload wheel - if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork }} + if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} id: upload env: - AWS_DEFAULT_REGION: ${{ secrets. 
AWS_DEFAULT_REGION }} + AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: | cd ./dist echo "WHL_PATH=$(ls *.whl)" >> "$GITHUB_OUTPUT" @@ -43,7 +44,8 @@ jobs: echo "SHORT_SHA=$sha_short" >> "$GITHUB_OUTPUT" [ -z "${{ github.head_ref }}" ] && dir="${{ github.ref_name }}" || dir="pr/${{ github.head_ref }}" echo "DIR=$dir" >> "$GITHUB_OUTPUT" - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://artifact.flower.dev/py/$dir/$sha_short --recursive + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://${{ env.ARTIFACT_BUCKET }}/py/$dir/$sha_short --recursive + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://${{ env.ARTIFACT_BUCKET }}/py/$dir/latest --recursive outputs: whl_path: ${{ steps.upload.outputs.WHL_PATH }} short_sha: ${{ steps.upload.outputs.SHORT_SHA }} @@ -73,7 +75,7 @@ jobs: dataset: | import tensorflow as tf tf.keras.datasets.cifar10.load_data() - + - directory: tabnet dataset: | import tensorflow_datasets as tfds @@ -83,17 +85,12 @@ jobs: dataset: | from torchvision.datasets import CIFAR10 CIFAR10('./data', download=True) - + - directory: pytorch-lightning dataset: | from torchvision.datasets import MNIST MNIST('./data', download=True) - - directory: mxnet - dataset: | - import mxnet as mx - mx.test_utils.get_mnist() - - directory: scikit-learn dataset: | import openml @@ -102,7 +99,7 @@ jobs: - directory: fastai dataset: | from fastai.vision.all import untar_data, URLs - untar_data(URLs.MNIST) + untar_data(URLs.MNIST) - directory: pandas dataset: | @@ -126,9 +123,9 @@ jobs: - name: Install dependencies run: python -m poetry install - name: Install Flower wheel from artifact store - if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork }} + if: ${{ 
github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Download dataset if: ${{ matrix.dataset }} run: python -c "${{ matrix.dataset }}" @@ -138,6 +135,12 @@ jobs: run: python simulation.py - name: Run driver test run: ./../test_driver.sh "${{ matrix.directory }}" + - name: Run driver test with REST + if: ${{ matrix.directory == 'bare' }} + run: ./../test_driver.sh bare rest + - name: Run driver test with SQLite database + if: ${{ matrix.directory == 'bare' }} + run: ./../test_driver.sh bare sqlite strategies: runs-on: ubuntu-22.04 @@ -161,11 +164,11 @@ jobs: run: | python -m poetry install - name: Install Flower wheel from artifact store - if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork }} + if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Cache Datasets - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: "~/.keras" key: keras-datasets diff --git a/.github/workflows/framework-draft-release.yml b/.github/workflows/framework-draft-release.yml new file mode 100644 index 000000000000..91a89953cf96 --- /dev/null +++ b/.github/workflows/framework-draft-release.yml @@ -0,0 +1,66 @@ +name: Draft release + +on: + 
push: + tags: + - "v*.*.*" + +env: + ARTIFACT_BUCKET: artifact.flower.ai + +jobs: + publish: + if: ${{ github.repository == 'adap/flower' }} + name: Publish draft + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Wait for wheel to be built + uses: lewagon/wait-on-check-action@v1.3.3 + with: + ref: ${{ github.ref }} + check-name: 'Build, test and upload wheel' + repo-token: ${{ secrets.GITHUB_TOKEN }} + wait-interval: 10 + - name: Download wheel + run: | + tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) + echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" + + wheel_name="flwr-${tag_name}-py3-none-any.whl" + echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" + + tar_name="flwr-${tag_name}.tar.gz" + echo "TAR_NAME=$tar_name" >> "$GITHUB_ENV" + + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${tar_name}" + + curl $wheel_url --output $wheel_name + curl $tar_url --output $tar_name + - name: Upload wheel + env: + AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets. 
AWS_SECRET_ACCESS_KEY }} + run: | + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} + + - name: Generate body + run: | + ./dev/get-latest-changelog.sh > body.md + cat body.md + + - name: Release + uses: softprops/action-gh-release@v1 + with: + body_path: ./body.md + draft: true + name: Flower ${{ env.TAG_NAME }} + files: | + ${{ env.WHEEL_NAME }} + ${{ env.TAR_NAME }} diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index eab15a51d217..04b68fd38af9 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -1,63 +1,45 @@ -name: Release Framework +name: Publish `flwr` release on PyPI on: - push: - tags: - - "v*.*.*" + release: + types: [released] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + ARTIFACT_BUCKET: artifact.flower.ai jobs: publish: if: ${{ github.repository == 'adap/flower' }} - name: Publish draft + name: Publish release runs-on: ubuntu-22.04 steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Wait for wheel to be built - uses: lewagon/wait-on-check-action@v1.3.1 - with: - ref: ${{ github.ref }} - check-name: 'Build, test and upload wheel' - repo-token: ${{ secrets.GITHUB_TOKEN }} - wait-interval: 10 - - name: Download wheel - run: | - tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) - echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" - - wheel_name="flwr-${tag_name}-py3-none-any.whl" - echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" - - tar_name="flwr-${tag_name}.tar.gz" - echo 
"TAR_NAME=$tar_name" >> "$GITHUB_ENV" - - wheel_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${tar_name}" - - curl $wheel_url --output $wheel_name - curl $tar_url --output $tar_name - - name: Upload wheel - env: - AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} - run: | - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} - - - name: Generate body - run: | - ./dev/get-latest-changelog.sh > body.md - cat body.md - - - name: Release - uses: softprops/action-gh-release@v1 - with: - body_path: ./body.md - draft: true - name: Flower ${{ env.TAG_NAME }} - files: | - ${{ env.WHEEL_NAME }} - ${{ env.TAR_NAME }} + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Bootstrap + uses: ./.github/actions/bootstrap + + - name: Get artifacts and publish + env: + GITHUB_REF: ${{ github.ref }} + run: | + TAG_NAME=$(echo "${GITHUB_REF_NAME}" | cut -c2-) + + wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" + tar_name="flwr-${TAG_NAME}.tar.gz" + + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${tar_name}" + + mkdir -p dist + + curl $wheel_url --output dist/$wheel_name + curl $tar_url --output dist/$tar_name + + python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/framework.yml b/.github/workflows/framework.yml index 0899940636d2..784f04750c5e 100644 --- a/.github/workflows/framework.yml +++ 
b/.github/workflows/framework.yml @@ -25,7 +25,7 @@ jobs: # In case of a mismatch, the job has to download Python to install it. # Note: Due to a bug in actions/setup-python we have to put 3.10 in # qoutes as it will otherwise will assume 3.1 - python: [3.8, 3.9, '3.10'] + python: [3.8, 3.9, '3.10', '3.11'] name: Python ${{ matrix.python }} diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index 8758d0e1c5c7..2ca596a59361 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -20,7 +20,7 @@ jobs: name: Test runs-on: macos-latest steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 @@ -31,7 +31,7 @@ jobs: runs-on: macos-latest name: Build docs steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 @@ -44,7 +44,7 @@ jobs: runs-on: macos-latest name: Deploy docs steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 diff --git a/.github/workflows/update-pr.yml b/.github/workflows/update-pr.yml index 78ef5bc86772..64b16aeabebf 100644 --- a/.github/workflows/update-pr.yml +++ b/.github/workflows/update-pr.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Automatically update mergeable PRs - uses: adRise/update-pr-branch@v0.7.0 + uses: adRise/update-pr-branch@cd305ecbd76bf63056c9400ce2c725293fc3e0c0 # v0.7.0 with: token: ${{ secrets.FLWRMACHINE_TOKEN }} base: 'main' diff --git a/CHANGELOG.md b/CHANGELOG.md index 264548c87669..1f01c0e9717d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Changelog -Flower changes are tracked as part of the documentation: [Flower 
Changelog](https://flower.dev/docs/changelog.html). +Flower changes are tracked as part of the documentation: [Flower Changelog](https://flower.ai/docs/changelog.html). The changelog source can be edited here: `doc/source/changelog.rst` diff --git a/README.md b/README.md index b8b62e8c0c43..90faa2358fa6 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ # Flower: A Friendly Federated Learning Framework

- - Flower Website + + Flower Website

- Website | - Blog | - Docs | - Conference | - Slack + Website | + Blog | + Docs | + Conference | + Slack

@@ -18,7 +18,7 @@ [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/adap/flower/blob/main/CONTRIBUTING.md) ![Build](https://github.com/adap/flower/actions/workflows/framework.yml/badge.svg) [![Downloads](https://static.pepy.tech/badge/flwr)](https://pepy.tech/project/flwr) -[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.dev/join-slack) +[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.ai/join-slack) Flower (`flwr`) is a framework for building federated learning systems. The design of Flower is based on a few guiding principles: @@ -33,14 +33,13 @@ design of Flower is based on a few guiding principles: - **Framework-agnostic**: Different machine learning frameworks have different strengths. Flower can be used with any machine learning framework, for - example, [PyTorch](https://pytorch.org), - [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [MXNet](https://mxnet.apache.org/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) + example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) for users who enjoy computing gradients by hand. 
- **Understandable**: Flower is written with maintainability in mind. The community is encouraged to both read and contribute to the codebase. -Meet the Flower community on [flower.dev](https://flower.dev)! +Meet the Flower community on [flower.ai](https://flower.ai)! ## Federated Learning Tutorial @@ -56,11 +55,11 @@ Flower's goal is to make federated learning accessible to everyone. This series 2. **Using Strategies in Federated Learning** - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-use-a-federated-learning-strategy-pytorch.ipynb)) + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb)) 3. **Building Strategies for Federated Learning** - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb)) + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb)) 4. 
**Custom Clients for Federated Learning** @@ -74,20 +73,19 @@ Stay tuned, more tutorials are coming soon. Topics include **Privacy and Securit ## Documentation -[Flower Docs](https://flower.dev/docs): - -- [Installation](https://flower.dev/docs/framework/how-to-install-flower.html) -- [Quickstart (TensorFlow)](https://flower.dev/docs/framework/tutorial-quickstart-tensorflow.html) -- [Quickstart (PyTorch)](https://flower.dev/docs/framework/tutorial-quickstart-pytorch.html) -- [Quickstart (Hugging Face)](https://flower.dev/docs/framework/tutorial-quickstart-huggingface.html) -- [Quickstart (PyTorch Lightning [code example])](https://flower.dev/docs/framework/tutorial-quickstart-pytorch-lightning.html) -- [Quickstart (MXNet)](https://flower.dev/docs/framework/example-mxnet-walk-through.html) -- [Quickstart (Pandas)](https://flower.dev/docs/framework/tutorial-quickstart-pandas.html) -- [Quickstart (fastai)](https://flower.dev/docs/framework/tutorial-quickstart-fastai.html) -- [Quickstart (JAX)](https://flower.dev/docs/framework/tutorial-quickstart-jax.html) -- [Quickstart (scikit-learn)](https://flower.dev/docs/framework/tutorial-quickstart-scikitlearn.html) -- [Quickstart (Android [TFLite])](https://flower.dev/docs/framework/tutorial-quickstart-android.html) -- [Quickstart (iOS [CoreML])](https://flower.dev/docs/framework/tutorial-quickstart-ios.html) +[Flower Docs](https://flower.ai/docs): + +- [Installation](https://flower.ai/docs/framework/how-to-install-flower.html) +- [Quickstart (TensorFlow)](https://flower.ai/docs/framework/tutorial-quickstart-tensorflow.html) +- [Quickstart (PyTorch)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch.html) +- [Quickstart (Hugging Face)](https://flower.ai/docs/framework/tutorial-quickstart-huggingface.html) +- [Quickstart (PyTorch Lightning)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch-lightning.html) +- [Quickstart (Pandas)](https://flower.ai/docs/framework/tutorial-quickstart-pandas.html) +- 
[Quickstart (fastai)](https://flower.ai/docs/framework/tutorial-quickstart-fastai.html) +- [Quickstart (JAX)](https://flower.ai/docs/framework/tutorial-quickstart-jax.html) +- [Quickstart (scikit-learn)](https://flower.ai/docs/framework/tutorial-quickstart-scikitlearn.html) +- [Quickstart (Android [TFLite])](https://flower.ai/docs/framework/tutorial-quickstart-android.html) +- [Quickstart (iOS [CoreML])](https://flower.ai/docs/framework/tutorial-quickstart-ios.html) ## Flower Baselines @@ -100,17 +98,24 @@ Flower Baselines is a collection of community-contributed projects that reproduc - [FedMLB](https://github.com/adap/flower/tree/main/baselines/fedmlb) - [FedPer](https://github.com/adap/flower/tree/main/baselines/fedper) - [FedProx](https://github.com/adap/flower/tree/main/baselines/fedprox) +- [FedNova](https://github.com/adap/flower/tree/main/baselines/fednova) +- [HeteroFL](https://github.com/adap/flower/tree/main/baselines/heterofl) +- [FedAvgM](https://github.com/adap/flower/tree/main/baselines/fedavgm) +- [FedStar](https://github.com/adap/flower/tree/main/baselines/fedstar) - [FedWav2vec2](https://github.com/adap/flower/tree/main/baselines/fedwav2vec2) - [FjORD](https://github.com/adap/flower/tree/main/baselines/fjord) - [MOON](https://github.com/adap/flower/tree/main/baselines/moon) - [niid-Bench](https://github.com/adap/flower/tree/main/baselines/niid_bench) - [TAMUNA](https://github.com/adap/flower/tree/main/baselines/tamuna) +- [FedVSSL](https://github.com/adap/flower/tree/main/baselines/fedvssl) +- [FedXGBoost](https://github.com/adap/flower/tree/main/baselines/hfedxgboost) +- [FedPara](https://github.com/adap/flower/tree/main/baselines/fedpara) - [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist) - [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization) -Please refer to the [Flower Baselines 
Documentation](https://flower.dev/docs/baselines/) for a detailed categorization of baselines and for additional info including: -* [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) -* [How to contribute a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) +Please refer to the [Flower Baselines Documentation](https://flower.ai/docs/baselines/) for a detailed categorization of baselines and for additional info including: +* [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) +* [How to contribute a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) ## Flower Usage Examples @@ -124,24 +129,32 @@ Quickstart examples: - [Quickstart (PyTorch Lightning)](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch-lightning) - [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai) - [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas) -- [Quickstart (MXNet)](https://github.com/adap/flower/tree/main/examples/quickstart-mxnet) - [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax) +- [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai) - [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist) - [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android) - [Quickstart (iOS [CoreML])](https://github.com/adap/flower/tree/main/examples/ios) +- [Quickstart (MLX)](https://github.com/adap/flower/tree/main/examples/quickstart-mlx) +- [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) Other [examples](https://github.com/adap/flower/tree/main/examples): - [Raspberry Pi & Nvidia Jetson Tutorial](https://github.com/adap/flower/tree/main/examples/embedded-devices) - 
[PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated) -- [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet-from-centralized-to-federated) +- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl) +- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning) +- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/fedllm-finetune) +- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune) - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow) - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch) -- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation_pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation_tensorflow)) +- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)) +- [Comprehensive Flower+XGBoost](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) +- [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose) +- [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplna-meier-fitter) ## Community -Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.dev/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome. 
+Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.ai/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome. diff --git a/baselines/README.md b/baselines/README.md index a18c0553b2b4..3a84df02d8de 100644 --- a/baselines/README.md +++ b/baselines/README.md @@ -1,7 +1,7 @@ # Flower Baselines -> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). +> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). > The documentation below has been updated to reflect the new way of using Flower baselines. @@ -23,7 +23,7 @@ Please note that some baselines might include additional files (e.g. a `requirem ## Running the baselines -Each baseline is self-contained in its own directory. 
Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't setup `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.dev/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so. +Each baseline is self-contained in its own directory. Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't setup `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.ai/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so. Assuming `pyenv` and `Poetry` are already installed on your system. Running a baseline can be done by: @@ -54,7 +54,7 @@ The steps to follow are: ```bash # This will create a new directory with the same structure as `baseline_template`. ./dev/create-baseline.sh - ``` + ``` 3. Then, go inside your baseline directory and continue with the steps detailed in `EXTENDED_README.md` and `README.md`. 4. Once your code is ready and you have checked that following the instructions in your `README.md` the Python environment can be created correctly and that running the code following your instructions can reproduce the experiments in the paper, you just need to create a Pull Request (PR). Then, the process to merge your baseline into the Flower repo will begin! 
diff --git a/baselines/baseline_template/pyproject.toml b/baselines/baseline_template/pyproject.toml index da7516437f09..31f1ee7bfe6d 100644 --- a/baselines/baseline_template/pyproject.toml +++ b/baselines/baseline_template/pyproject.toml @@ -7,11 +7,11 @@ name = "" # <----- Ensure it matches the name of your baseline di version = "1.0.0" description = "Flower Baselines" license = "Apache-2.0" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -42,9 +42,9 @@ flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -80,10 +80,10 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [tool.pylint.typecheck] -generated-members="numpy.*, torch.*, tensorflow.*" +generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/dasha/pyproject.toml b/baselines/dasha/pyproject.toml index 3ef24e4b985a..f03ad06e26e4 100644 --- a/baselines/dasha/pyproject.toml +++ b/baselines/dasha/pyproject.toml @@ -9,9 +9,9 @@ description = "DASHA: Distributed nonconvex optimization with communication comp license = "Apache-2.0" authors = ["Alexander Tyurin "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = 
"https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -49,9 +49,9 @@ torchvision = [{ url = "https://download.pytorch.org/whl/cu118/torchvision-0.15. { url = "https://download.pytorch.org/whl/cpu/torchvision-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", markers="sys_platform == 'darwin'"}] [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -60,6 +60,7 @@ pytest-watch = "==4.2.0" types-requests = "==2.27.7" py-spy = "==0.3.14" ruff = "==0.0.272" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -88,8 +89,8 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias,no-self-use,too-many-locals,too-many-instance-attributes" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" -generated-members="numpy.*, torch.*" +signature-mutators = "hydra.main.main" +generated-members = "numpy.*, torch.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/depthfl/README.md b/baselines/depthfl/README.md index b8ab7ed18571..ab26f9c8f8c9 100644 --- a/baselines/depthfl/README.md +++ b/baselines/depthfl/README.md @@ -124,7 +124,7 @@ python -m depthfl.main --multirun exclusive_learning=true model_size=1,2,3,4 **Table 2** -100% (a), 75%(b), 50%(c), 25% (d) cases are exclusive learning scenario. 100% (a) exclusive learning means, the global model and every local model are equal to the smallest local model, and 100% clients participate in learning. Likewise, 25% (d) exclusive learning means, the global model and every local model are equal to the larget local model, and only 25% clients participate in learning. 
+100% (a), 75%(b), 50%(c), 25% (d) cases are exclusive learning scenario. 100% (a) exclusive learning means, the global model and every local model are equal to the smallest local model, and 100% clients participate in learning. Likewise, 25% (d) exclusive learning means, the global model and every local model are equal to the largest local model, and only 25% clients participate in learning. | Scaling Method | Dataset | Global Model | 100% (a) | 75% (b) | 50% (c) | 25% (d) | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | diff --git a/baselines/depthfl/depthfl/models.py b/baselines/depthfl/depthfl/models.py index df3eebf9f9ce..5ce5eedb360e 100644 --- a/baselines/depthfl/depthfl/models.py +++ b/baselines/depthfl/depthfl/models.py @@ -1,6 +1,5 @@ """ResNet18 model architecutre, training, and testing functions for CIFAR100.""" - from typing import List, Tuple import torch diff --git a/baselines/depthfl/pyproject.toml b/baselines/depthfl/pyproject.toml index 2f928c2d3553..e59d16fef88e 100644 --- a/baselines/depthfl/pyproject.toml +++ b/baselines/depthfl/pyproject.toml @@ -9,9 +9,9 @@ description = "DepthFL: Depthwise Federated Learning for Heterogeneous Clients" license = "Apache-2.0" authors = ["Minjae Kim "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -37,18 +37,17 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10.0, <3.11.0" +python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this matplotlib = "3.7.1" -torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} -torchvision = { url = 
"https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} - +torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } +torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -56,6 +55,7 @@ pytest = "==6.2.4" pytest-watch = "==4.2.0" ruff = "==0.0.272" types-requests = "==2.27.7" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -84,10 +84,10 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [tool.pylint.typecheck] -generated-members="numpy.*, torch.*, tensorflow.*" +generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/doc/source/_templates/base.html b/baselines/doc/source/_templates/base.html index 1cee99053fbe..cc171be9e6b0 100644 --- a/baselines/doc/source/_templates/base.html +++ b/baselines/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/baselines/doc/source/conf.py b/baselines/doc/source/conf.py index e1055aaed216..dad8650cddaa 100644 --- a/baselines/doc/source/conf.py +++ b/baselines/doc/source/conf.py @@ -14,6 +14,7 @@ # ============================================================================== +import datetime import os import sys from sphinx.application import ConfigError @@ -32,7 +33,7 @@ # -- Project information ----------------------------------------------------- project 
= "Flower" -copyright = "2022 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags @@ -84,7 +85,7 @@ html_title = f"Flower Baselines {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/baselines/" +html_baseurl = "https://flower.ai/docs/baselines/" html_theme_options = { # diff --git a/baselines/doc/source/how-to-use-baselines.rst b/baselines/doc/source/how-to-use-baselines.rst index ed47438ad5a9..4704a9b6074e 100644 --- a/baselines/doc/source/how-to-use-baselines.rst +++ b/baselines/doc/source/how-to-use-baselines.rst @@ -33,7 +33,7 @@ Setting up your machine Common to all baselines is `Poetry `_, a tool to manage Python dependencies. Baselines also make use of `Pyenv `_. You'll need to install both on your system before running a baseline. What follows is a step-by-step guide on getting :code:`pyenv` and :code:`Poetry` installed on your system. -Let's begin by installing :code:`pyenv`. We'll be following the standard procedure. Please refere to the `pyenv docs `_ for alternative ways of installing it. +Let's begin by installing :code:`pyenv`. We'll be following the standard procedure. Please refer to the `pyenv docs `_ for alternative ways of installing it. .. code-block:: bash @@ -49,7 +49,7 @@ Let's begin by installing :code:`pyenv`. We'll be following the standard procedu command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" eval "$(pyenv init -)" -Verify your installtion by opening a new terminal and +Verify your installation by opening a new terminal and .. code-block:: bash @@ -63,7 +63,7 @@ Then you can proceed and install any version of Python. 
Most baselines currently pyenv install 3.10.6 # this will take a little while - # once done, you should see that that version is avaialble + # once done, you should see that that version is available pyenv versions # system # * 3.10.6 # <-- you just installed this diff --git a/baselines/doc/source/index.rst b/baselines/doc/source/index.rst index 335cfacef1ab..3a19e74b891e 100644 --- a/baselines/doc/source/index.rst +++ b/baselines/doc/source/index.rst @@ -1,7 +1,7 @@ Flower Baselines Documentation ============================== -Welcome to Flower Baselines' documentation. `Flower `_ is a friendly federated learning framework. +Welcome to Flower Baselines' documentation. `Flower `_ is a friendly federated learning framework. Join the Flower Community @@ -9,7 +9,7 @@ Join the Flower Community The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts. -.. button-link:: https://flower.dev/join-slack +.. button-link:: https://flower.ai/join-slack :color: primary :shadow: diff --git a/baselines/fedavgm/LICENSE b/baselines/fedavgm/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/fedavgm/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/fedavgm/README.md b/baselines/fedavgm/README.md new file mode 100644 index 000000000000..5b8ddfcad6a8 --- /dev/null +++ b/baselines/fedavgm/README.md @@ -0,0 +1,220 @@ +--- +title: Measuring the effects of non-identical data distribution for federated visual classification +url: https://arxiv.org/abs/1909.06335 +labels: [non-iid, image classification] +dataset: [CIFAR-10, Fashion-MNIST] +--- + +# Measuring the effects of non-identical data distribution for federated visual classification + +> Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper. + +**Paper:** [arxiv.org/abs/1909.06335](https://arxiv.org/abs/1909.06335) + +**Authors:** Tzu-Ming Harry Hsu, Hang Qi, Matthew Brown + +**Abstract:** Federated Learning enables visual models to be trained in a privacy-preserving way using real-world data from mobile devices. Given their distributed nature, the statistics of the data across these devices is likely to differ significantly. In this work, we look at the effect such non-identical data distributions has on visual classification via Federated Learning. We propose a way to synthesize datasets with a continuous range of identicalness and provide performance measures for the Federated Averaging algorithm. 
We show that performance degrades as distributions differ more, and propose a mitigation strategy via server momentum. Experiments on CIFAR-10 demonstrate improved classification performance over a range of non-identicalness, with classification accuracy improved from 30.1% to 76.9% in the most skewed settings. + + +## About this baseline + +**What’s implemented:** The code in this directory evaluates the effects of non-identical data distribution for a visual classification task based on the paper _Measuring the effects of non-identical data distribution for federated visual classification_ (Hsu et al., 2019). It reproduces the FedAvgM and FedAvg performance curves for different non-identical-ness of the dataset (CIFAR-10 and Fashion-MNIST). _Figure 5 in the paper, section 4.2._ + +**Datasets:** CIFAR-10, and Fashion-MNIST + +**Hardware Setup:** This baseline was evaluated on a regular PC without GPU (Intel i7-10710U CPU, and 32 GB RAM). The major constraint is to run a huge number of rounds, as in the reference paper, which reports 10,000 rounds for each case evaluated. + +**Contributors:** Gustavo Bertoli [(@gubertoli)](https://github.com/gubertoli) + +## Experimental Setup + +**Task:** Image Classification + +**Model:** This directory implements a CNN model similar to the one used on the seminal FedAvg paper (`models.py`): + +- McMahan, B., Moore, E., Ramage, D., Hampson, S., & y Arcas, B. A. (2017, April). Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics (pp. 1273-1282). PMLR. ([Link](http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf)): + +As the following excerpt: + +> "*We also ran experiments on the CIFAR-10 dataset...
The model architecture was taken from the TensorFlow tutorial [38], which consists of two convolutional layers followed by two fully connected layers and then a linear transformation layer to produce logits, for a total of about 10⁶ parameters."* + +Regarding this architecture, the historical references mentioned on the FedAvg and FedAvgM papers are [this](https://web.archive.org/web/20190415103404/https://www.tensorflow.org/tutorials/images/deep_cnn) and [this](https://web.archive.org/web/20170807002954/https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py). + +Important to highlight the rationale with this CNN model stated on FedAvgM paper: + +> "*This model is not the state-of-the-art on the CIFAR-10 dataset, but is sufficient to show relative performance for the purposes of our investigation."* + +**The default CNN model in use on this baseline has a centralized accuracy of ~0.74. That is different from the reported 0.86 accuracy from the original FedAvg paper. But it is still sufficient to show the relative performance for the purposes of the FedAvgM investigation.** + +**Dataset:** This baseline includes the CIFAR-10 and Fashion-MNIST datasets. By default it will run with the CIFAR-10. The data partition uses a configurable Latent Dirichlet Allocation (LDA) distribution (`concentration` parameter equals 0.1 as default) to create **non-iid distributions** between the clients. The understanding for this `concentration` (α) is that as α→∞, all clients have identical distributions, and as α→0, each client holds samples from only one class.
+ +| Dataset | # classes | # partitions | partition method | partition settings| +| :------ | :---: | :---: | :---: | :---: | +| CIFAR-10 | 10 | `num_clients` | Latent Dirichlet Allocation (LDA) | `concentration` | +| Fashion-MNIST | 10 | `num_clients` | Latent Dirichlet Allocation (LDA) | `concentration` | + +**Data distribution:** The following figure illustrates the use of multiple `concentration` values to generate the data distribution over 30 clients for CIFAR-10 (10 classes) - [source code](fedavgm/utils.py): + +![](_static/concentration_cifar10_v2.png) + +**Training Hyperparameters:** +The following table shows the main hyperparameters for this baseline with their default value (i.e. the value used if you run `python main.py` directly) + +| Description | Default Value | +| ----------- | ----- | +| total clients | 10 | +| number of rounds | 5 | +| model | CNN | +| strategy | Custom FedAvgM | +| dataset | CIFAR-10 | +| concentration | 0.1 | +| fraction evaluate | 0 | +| num cpus | 1 | +| num gpus | 0 | +| server momentum | 0.9 | +| server learning rate | 1.0 | +| server reporting fraction | 0.05 | +| client local epochs | 1 | +| client batch size | 64 | +| client learning rate | 0.01 | + +### Custom FedAvgM +In contrast to the initial implementation found in Flower v1.5.0, our baseline incorporates the Nesterov accelerated gradient as a pivotal component of the momentum applied to the server model. It is worth emphasizing that the inclusion of Nesterov momentum aligns with the original definition of FedAvgM in the research paper. + +To use the original Flower implementation, use the argument `strategy=fedavgm`. By default, the custom implementation is used. But, you can also refer to it on the command line as `strategy=custom-fedavgm`. + +## Environment Setup + +### Specifying the Python Version + +This baseline was tested with Python 3.10.6 and following the steps below to construct the Python environment and install all dependencies. 
Both [`pyenv`](https://github.com/pyenv/pyenv) and [`poetry`](https://python-poetry.org/docs/) are assumed to be already present in your system. + +```bash +# Cd to your baseline directory (i.e. where the `pyproject.toml` is), then +pyenv local 3.10.6 + +# Set that version for poetry +poetry env use 3.10.6 + +# Install the base Poetry environment +poetry install + +# Activate the environment +poetry shell +``` + +### Google Colab +If you want to set up the environment on Google Colab, please execute the script `conf-colab.sh`, just use the Colab terminal and the following: + +```bash +chmod +x conf-colab.sh +./conf-colab.sh +``` + +## Running the Experiments + +To run this FedAvgM with CIFAR-10 baseline, first ensure you have activated your Poetry environment (execute `poetry shell` from this directory), then: + +```bash +python -m fedavgm.main # this will run using the default setting in the `conf/base.yaml` + +# you can override settings directly from the command line + +python -m fedavgm.main strategy=fedavg num_clients=1000 num_rounds=50 # will set the FedAvg with 1000 clients and 50 rounds + +python -m fedavgm.main dataset=fmnist noniid.concentration=10 # use the Fashion-MNIST dataset and a different concentration for the LDA-based partition + +python -m fedavgm.main server.reporting_fraction=0.2 client.local_epochs=5 # will set the reporting fraction to 20% and the local epochs in the clients to 5 +``` + +## Expected Results + +### CIFAR-10 +Similar to the reference FedAvgM paper, the CIFAR-10 evaluation runs 10,000 rounds. + +> In order to speed up the execution of these experiments, the evaluation of the _global model_ on the test set only takes place after the last round. The highest accuracy is achieved towards the last rounds, not necessarily in the last. If you wish to evaluate the _global model_ on the test set (or a validation set) more frequently, edit `get_evaluate_fn` in `server.py`.
Overall, running the experiments as shown below demonstrates that `FedAvgM` is consistently superior to `FedAvg`. + +For the FedAvgM evaluation, a hyperparameter search over server momentum and client learning rate (similar to Figure 6 reported below) was performed for each of the concentrations under analysis, using the following commands: + +- Concentration = 1e-5 and 1e-9 (extreme non-iid) +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1e-5,1e-9 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.99 +``` + +- Concentration = 0.01 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=0.01 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.003 server.momentum=0.97 +``` + +- Concentration = 0.1 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=0.1 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.99 +``` + +- Concentration = 1 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.997 +``` + +- Concentration = 10 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=10 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.003 server.momentum=0.9 +``` + +Summarizing all the results: + +![](_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png) + +The findings align with the statement in the original FedAvgM paper that *"To prevent client updates from diverging, we additionally have to use a
combination of low absolute learning rate and high momentum"*. + +The following command reproduces the same behavior as Figure 6 from the FedAvgM paper for the case of Local Epoch E=1, Reporting Fraction C=0.05, and concentration (α) = 1. In this example, it runs just 1,000 rounds: + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1 \ +strategy=custom-fedavgm server.reporting_fraction=0.05 num_rounds=1000 num_clients=100 \ +dataset=cifar10 client.lr=0.0001,0.0003,0.001,0.003,0.01,0.03,0.1,0.3 \ +server.momentum=0.7,0.9,0.97,0.99,0.997 +``` + +![](_static/Figure6_cifar10_num-rounds=1000_concentration=1.png) + + +--- +### Fashion-MNIST + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 \ +noniid.concentration=0.001,0.01,0.1,1,10,100 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=1000 \ +num_clients=100 dataset=fmnist server.momentum=0.97 client.lr=0.003 +``` +The above command will evaluate the custom FedAvgM versus FedAvg on the Fashion-MNIST dataset. It uses 100 clients with a reporting fraction of 5% during 1000 rounds.
To evaluate the non-iid aspects, this execution exercises concentrations of [100, 10, 1, 0.1, 0.01, 0.001]: + +![](_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png) + +#### Comparison between the Custom-FedAvgM and FedAvgM + +To compare the improvement of the FedAvgM with Nesterov momentum (`strategy=custom-fedavgm`) and the FedAvgM without the Nesterov momentum (`strategy=fedavgm`), here we use the results of the previous run, with the addition of the same conditions for the `fedavgm` strategy as follows: + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 \ +noniid.concentration=0.001,0.01,0.1,1,10,100 strategy=fedavgm \ +server.reporting_fraction=0.05 num_rounds=1000 \ +num_clients=100 dataset=fmnist server.momentum=0.97 client.lr=0.003 +``` + +![](_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png) + +Overall, FedAvgM with Nesterov momentum outperforms the FedAvgM without Nesterov momentum, with this behavior being clearest at higher non-iidness (0.01 and 0.001). At these higher levels of non-iidness, the test accuracy of FedAvgM without Nesterov momentum is worse than that of FedAvg. +For larger concentrations (1, 10, 100), some runs were observed where the centralized evaluation resulted in a loss equal to NaN or Inf, so multiple runs were required to guarantee the accuracies reported.
+ diff --git a/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb b/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb new file mode 100644 index 000000000000..fac837d145e3 --- /dev/null +++ b/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb @@ -0,0 +1,1851 @@ +{ + "cells": [ + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "from keras.optimizers import SGD\n", + "from keras.regularizers import l2\n", + "from tensorflow import keras\n", + "from tensorflow.nn import local_response_normalization\n", + "from keras.utils import to_categorical\n", + "import matplotlib.pyplot as plt" + ], + "metadata": { + "id": "Rp9LUn54SUTu" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "7tTxE8D6bD6g" + }, + "outputs": [], + "source": [ + "def tf_example(input_shape, num_classes):\n", + " \"\"\"CNN Model from TensorFlow v1.x example.\n", + "\n", + " This is the model referenced on the FedAvg paper.\n", + "\n", + " Reference:\n", + " https://web.archive.org/web/20170807002954/https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py\n", + " \"\"\"\n", + " input_shape = tuple(input_shape)\n", + "\n", + " weight_decay = 0.004\n", + " model = keras.Sequential(\n", + " [\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " input_shape=input_shape,\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding=\"same\"),\n", + " keras.layers.Lambda(\n", + " local_response_normalization,\n", + " arguments={\n", + " \"depth_radius\": 4,\n", + " \"bias\": 1.0,\n", + " \"alpha\": 0.001 / 9.0,\n", + " \"beta\": 0.75,\n", + " },\n", + " ),\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " ),\n", + " keras.layers.Lambda(\n", + " 
local_response_normalization,\n", + " arguments={\n", + " \"depth_radius\": 4,\n", + " \"bias\": 1.0,\n", + " \"alpha\": 0.001 / 9.0,\n", + " \"beta\": 0.75,\n", + " },\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding=\"same\"),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(\n", + " 384, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(\n", + " 192, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(num_classes, activation=\"softmax\"),\n", + " ]\n", + " )\n", + " optimizer = SGD(learning_rate=0.1)\n", + " model.compile(\n", + " loss=\"categorical_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"]\n", + " )\n", + "\n", + " return model\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "def cifar10(num_classes, input_shape):\n", + " \"\"\"Prepare the CIFAR-10.\n", + "\n", + " This method considers CIFAR-10 for creating both train and test sets. The sets are\n", + " already normalized.\n", + " \"\"\"\n", + " print(f\">>> [Dataset] Loading CIFAR-10. {num_classes} | {input_shape}.\")\n", + " (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n", + " x_train = x_train.astype(\"float32\") / 255\n", + " x_test = x_test.astype(\"float32\") / 255\n", + " input_shape = x_train.shape[1:]\n", + " num_classes = len(np.unique(y_train))\n", + "\n", + " return x_train, y_train, x_test, y_test, input_shape, num_classes" + ], + "metadata": { + "id": "vuQykx1uSXHk" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FMph7H-qbHHR", + "outputId": "45cf4a68-7054-460e-bcd7-c353338dc387" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + ">>> [Dataset] Loading CIFAR-10. 
10 | (32, 32, 3).\n" + ] + } + ], + "source": [ + "x_train, y_train, x_test, y_test, input_shape,num_classes = cifar10(10, (32,32,3))\n" + ] + }, + { + "cell_type": "code", + "source": [ + "EPOCHS=350\n", + "BATCH_SIZE=128" + ], + "metadata": { + "id": "AD2qsybwX6uR" + }, + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "---" + ], + "metadata": { + "id": "531ZRrY2SY85" + } + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "1BO5D4ZBbJJo" + }, + "outputs": [], + "source": [ + "model = tf_example(input_shape, num_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8DMMAgw6bK2C", + "outputId": "9c1203a2-7152-4c25-dc1c-1c58b1cf8b8b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/350\n", + "391/391 [==============================] - 8s 18ms/step - loss: 4.8242 - accuracy: 0.2914\n", + "Epoch 2/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 3.0276 - accuracy: 0.4814\n", + "Epoch 3/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 2.1395 - accuracy: 0.5609\n", + "Epoch 4/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.6463 - accuracy: 0.6129\n", + "Epoch 5/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.3656 - accuracy: 0.6504\n", + "Epoch 6/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.1851 - accuracy: 0.6868\n", + "Epoch 7/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.0698 - accuracy: 0.7147\n", + "Epoch 8/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.9918 - accuracy: 0.7350\n", + "Epoch 9/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.9465 - accuracy: 0.7551\n", + "Epoch 10/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.8991 - accuracy: 0.7747\n", + "Epoch 11/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8534 - accuracy: 0.7971\n", + "Epoch 12/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8305 - accuracy: 0.8111\n", + "Epoch 13/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8070 - accuracy: 0.8265\n", + "Epoch 14/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7805 - accuracy: 0.8434\n", + "Epoch 15/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7672 - accuracy: 0.8527\n", + "Epoch 16/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7504 - accuracy: 0.8647\n", + "Epoch 17/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7418 - accuracy: 0.8715\n", + "Epoch 18/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7244 - accuracy: 0.8819\n", + "Epoch 19/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7205 - accuracy: 0.8871\n", + "Epoch 20/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7032 - accuracy: 0.8966\n", + "Epoch 21/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6965 - accuracy: 0.8999\n", + "Epoch 22/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6998 - accuracy: 0.9026\n", + "Epoch 23/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6952 - accuracy: 0.9065\n", + "Epoch 24/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6795 - accuracy: 0.9120\n", + "Epoch 25/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6913 - accuracy: 0.9100\n", + "Epoch 26/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6822 - accuracy: 0.9144\n", + "Epoch 27/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.6773 - accuracy: 0.9174\n", + "Epoch 28/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6885 - accuracy: 0.9155\n", + "Epoch 29/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6588 - accuracy: 0.9239\n", + "Epoch 30/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6631 - accuracy: 0.9230\n", + "Epoch 31/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6819 - accuracy: 0.9193\n", + "Epoch 32/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6557 - accuracy: 0.9271\n", + "Epoch 33/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6806 - accuracy: 0.9224\n", + "Epoch 34/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6525 - accuracy: 0.9299\n", + "Epoch 35/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6500 - accuracy: 0.9303\n", + "Epoch 36/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6701 - accuracy: 0.9234\n", + "Epoch 37/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6627 - accuracy: 0.9297\n", + "Epoch 38/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6507 - accuracy: 0.9321\n", + "Epoch 39/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6497 - accuracy: 0.9323\n", + "Epoch 40/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6593 - accuracy: 0.9304\n", + "Epoch 41/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6528 - accuracy: 0.9325\n", + "Epoch 42/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6294 - accuracy: 0.9365\n", + "Epoch 43/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6596 - accuracy: 0.9304\n", + "Epoch 44/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.6493 - accuracy: 0.9343\n", + "Epoch 45/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6440 - accuracy: 0.9351\n", + "Epoch 46/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6233 - accuracy: 0.9392\n", + "Epoch 47/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6631 - accuracy: 0.9301\n", + "Epoch 48/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6341 - accuracy: 0.9397\n", + "Epoch 49/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6440 - accuracy: 0.9351\n", + "Epoch 50/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6540 - accuracy: 0.9354\n", + "Epoch 51/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6371 - accuracy: 0.9407\n", + "Epoch 52/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6464 - accuracy: 0.9373\n", + "Epoch 53/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6489 - accuracy: 0.9371\n", + "Epoch 54/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6471 - accuracy: 0.9386\n", + "Epoch 55/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6342 - accuracy: 0.9414\n", + "Epoch 56/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6447 - accuracy: 0.9379\n", + "Epoch 57/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6291 - accuracy: 0.9431\n", + "Epoch 58/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6447 - accuracy: 0.9376\n", + "Epoch 59/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6493 - accuracy: 0.9401\n", + "Epoch 60/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6317 - accuracy: 0.9425\n", + "Epoch 61/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.6179 - accuracy: 0.9450\n", + "Epoch 62/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6549 - accuracy: 0.9370\n", + "Epoch 63/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6333 - accuracy: 0.9449\n", + "Epoch 64/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6345 - accuracy: 0.9409\n", + "Epoch 65/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6320 - accuracy: 0.9440\n", + "Epoch 66/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6361 - accuracy: 0.9423\n", + "Epoch 67/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6285 - accuracy: 0.9444\n", + "Epoch 68/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6324 - accuracy: 0.9427\n", + "Epoch 69/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6427 - accuracy: 0.9397\n", + "Epoch 70/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6429 - accuracy: 0.9436\n", + "Epoch 71/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6226 - accuracy: 0.9465\n", + "Epoch 72/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6406 - accuracy: 0.9411\n", + "Epoch 73/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6197 - accuracy: 0.9470\n", + "Epoch 74/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6285 - accuracy: 0.9434\n", + "Epoch 75/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6307 - accuracy: 0.9447\n", + "Epoch 76/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6243 - accuracy: 0.9465\n", + "Epoch 77/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6274 - accuracy: 0.9468\n", + "Epoch 78/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.6397 - accuracy: 0.9432\n", + "Epoch 79/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6282 - accuracy: 0.9468\n", + "Epoch 80/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6408 - accuracy: 0.9434\n", + "Epoch 81/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6368 - accuracy: 0.9468\n", + "Epoch 82/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6158 - accuracy: 0.9499\n", + "Epoch 83/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6100 - accuracy: 0.9478\n", + "Epoch 84/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6321 - accuracy: 0.9429\n", + "Epoch 85/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6260 - accuracy: 0.9477\n", + "Epoch 86/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6235 - accuracy: 0.9463\n", + "Epoch 87/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6186 - accuracy: 0.9493\n", + "Epoch 88/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6155 - accuracy: 0.9481\n", + "Epoch 89/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6702 - accuracy: 0.9374\n", + "Epoch 90/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6188 - accuracy: 0.9502\n", + "Epoch 91/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6410 - accuracy: 0.9439\n", + "Epoch 92/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6052 - accuracy: 0.9528\n", + "Epoch 93/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6349 - accuracy: 0.9431\n", + "Epoch 94/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6216 - accuracy: 0.9486\n", + "Epoch 95/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.6128 - accuracy: 0.9497\n", + "Epoch 96/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6286 - accuracy: 0.9469\n", + "Epoch 97/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6095 - accuracy: 0.9515\n", + "Epoch 98/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6124 - accuracy: 0.9487\n", + "Epoch 99/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6267 - accuracy: 0.9482\n", + "Epoch 100/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6323 - accuracy: 0.9459\n", + "Epoch 101/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9507\n", + "Epoch 102/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6192 - accuracy: 0.9478\n", + "Epoch 103/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6229 - accuracy: 0.9482\n", + "Epoch 104/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6261 - accuracy: 0.9486\n", + "Epoch 105/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6140 - accuracy: 0.9521\n", + "Epoch 106/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6256 - accuracy: 0.9476\n", + "Epoch 107/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6118 - accuracy: 0.9525\n", + "Epoch 108/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6064 - accuracy: 0.9502\n", + "Epoch 109/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6161 - accuracy: 0.9487\n", + "Epoch 110/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6105 - accuracy: 0.9513\n", + "Epoch 111/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6302 - accuracy: 0.9468\n", + "Epoch 112/350\n", + 
"391/391 [==============================] - 7s 18ms/step - loss: 0.6022 - accuracy: 0.9534\n", + "Epoch 113/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5993 - accuracy: 0.9518\n", + "Epoch 114/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6260 - accuracy: 0.9462\n", + "Epoch 115/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6026 - accuracy: 0.9538\n", + "Epoch 116/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6144 - accuracy: 0.9499\n", + "Epoch 117/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6054 - accuracy: 0.9516\n", + "Epoch 118/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6122 - accuracy: 0.9504\n", + "Epoch 119/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6187 - accuracy: 0.9506\n", + "Epoch 120/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6030 - accuracy: 0.9524\n", + "Epoch 121/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6078 - accuracy: 0.9513\n", + "Epoch 122/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6113 - accuracy: 0.9503\n", + "Epoch 123/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6080 - accuracy: 0.9525\n", + "Epoch 124/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5991 - accuracy: 0.9539\n", + "Epoch 125/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5985 - accuracy: 0.9529\n", + "Epoch 126/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6103 - accuracy: 0.9509\n", + "Epoch 127/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5947 - accuracy: 0.9557\n", + "Epoch 128/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5945 - accuracy: 0.9532\n", + "Epoch 
129/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6059 - accuracy: 0.9520\n", + "Epoch 130/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6076 - accuracy: 0.9517\n", + "Epoch 131/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6134 - accuracy: 0.9520\n", + "Epoch 132/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5950 - accuracy: 0.9546\n", + "Epoch 133/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5881 - accuracy: 0.9557\n", + "Epoch 134/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6095 - accuracy: 0.9494\n", + "Epoch 135/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9537\n", + "Epoch 136/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5860 - accuracy: 0.9554\n", + "Epoch 137/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6058 - accuracy: 0.9519\n", + "Epoch 138/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6043 - accuracy: 0.9542\n", + "Epoch 139/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5921 - accuracy: 0.9556\n", + "Epoch 140/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5983 - accuracy: 0.9530\n", + "Epoch 141/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5987 - accuracy: 0.9537\n", + "Epoch 142/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5983 - accuracy: 0.9544\n", + "Epoch 143/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5734 - accuracy: 0.9576\n", + "Epoch 144/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5895 - accuracy: 0.9534\n", + "Epoch 145/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6068 - accuracy: 0.9519\n", + 
"Epoch 146/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5973 - accuracy: 0.9548\n", + "Epoch 147/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5786 - accuracy: 0.9566\n", + "Epoch 148/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5833 - accuracy: 0.9547\n", + "Epoch 149/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6056 - accuracy: 0.9511\n", + "Epoch 150/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6060 - accuracy: 0.9517\n", + "Epoch 151/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5907 - accuracy: 0.9567\n", + "Epoch 152/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5922 - accuracy: 0.9541\n", + "Epoch 153/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5961 - accuracy: 0.9527\n", + "Epoch 154/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5878 - accuracy: 0.9580\n", + "Epoch 155/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5790 - accuracy: 0.9580\n", + "Epoch 156/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5977 - accuracy: 0.9523\n", + "Epoch 157/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5968 - accuracy: 0.9540\n", + "Epoch 158/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5950 - accuracy: 0.9547\n", + "Epoch 159/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5907 - accuracy: 0.9554\n", + "Epoch 160/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9560\n", + "Epoch 161/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5962 - accuracy: 0.9536\n", + "Epoch 162/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5876 - accuracy: 
0.9572\n", + "Epoch 163/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5818 - accuracy: 0.9558\n", + "Epoch 164/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5896 - accuracy: 0.9541\n", + "Epoch 165/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5915 - accuracy: 0.9552\n", + "Epoch 166/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5928 - accuracy: 0.9555\n", + "Epoch 167/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5773 - accuracy: 0.9576\n", + "Epoch 168/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9560\n", + "Epoch 169/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9563\n", + "Epoch 170/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5877 - accuracy: 0.9565\n", + "Epoch 171/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5893 - accuracy: 0.9554\n", + "Epoch 172/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5946 - accuracy: 0.9543\n", + "Epoch 173/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5841 - accuracy: 0.9571\n", + "Epoch 174/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5745 - accuracy: 0.9598\n", + "Epoch 175/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5715 - accuracy: 0.9580\n", + "Epoch 176/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5809 - accuracy: 0.9552\n", + "Epoch 177/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5845 - accuracy: 0.9557\n", + "Epoch 178/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5720 - accuracy: 0.9591\n", + "Epoch 179/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5901 - 
accuracy: 0.9541\n", + "Epoch 180/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5667 - accuracy: 0.9608\n", + "Epoch 181/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5857 - accuracy: 0.9552\n", + "Epoch 182/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5694 - accuracy: 0.9613\n", + "Epoch 183/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5732 - accuracy: 0.9574\n", + "Epoch 184/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5861 - accuracy: 0.9562\n", + "Epoch 185/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5737 - accuracy: 0.9580\n", + "Epoch 186/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5816 - accuracy: 0.9584\n", + "Epoch 187/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5566 - accuracy: 0.9602\n", + "Epoch 188/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5664 - accuracy: 0.9576\n", + "Epoch 189/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5911 - accuracy: 0.9535\n", + "Epoch 190/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5742 - accuracy: 0.9595\n", + "Epoch 191/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5748 - accuracy: 0.9559\n", + "Epoch 192/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5606 - accuracy: 0.9604\n", + "Epoch 193/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9508\n", + "Epoch 194/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5868 - accuracy: 0.9591\n", + "Epoch 195/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5400 - accuracy: 0.9650\n", + "Epoch 196/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 
0.5624 - accuracy: 0.9574\n", + "Epoch 197/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5864 - accuracy: 0.9554\n", + "Epoch 198/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5773 - accuracy: 0.9585\n", + "Epoch 199/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5699 - accuracy: 0.9580\n", + "Epoch 200/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5703 - accuracy: 0.9595\n", + "Epoch 201/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5723 - accuracy: 0.9601\n", + "Epoch 202/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5641 - accuracy: 0.9591\n", + "Epoch 203/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5812 - accuracy: 0.9565\n", + "Epoch 204/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5653 - accuracy: 0.9612\n", + "Epoch 205/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5697 - accuracy: 0.9592\n", + "Epoch 206/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5726 - accuracy: 0.9591\n", + "Epoch 207/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5570 - accuracy: 0.9612\n", + "Epoch 208/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5598 - accuracy: 0.9599\n", + "Epoch 209/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5709 - accuracy: 0.9578\n", + "Epoch 210/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5836 - accuracy: 0.9563\n", + "Epoch 211/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5621 - accuracy: 0.9613\n", + "Epoch 212/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5722 - accuracy: 0.9582\n", + "Epoch 213/350\n", + "391/391 [==============================] - 7s 18ms/step - 
loss: 0.5483 - accuracy: 0.9624\n", + "Epoch 214/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5708 - accuracy: 0.9563\n", + "Epoch 215/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5746 - accuracy: 0.9572\n", + "Epoch 216/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5840 - accuracy: 0.9584\n", + "Epoch 217/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5545 - accuracy: 0.9623\n", + "Epoch 218/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5402 - accuracy: 0.9628\n", + "Epoch 219/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5546 - accuracy: 0.9591\n", + "Epoch 220/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5762 - accuracy: 0.9552\n", + "Epoch 221/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5596 - accuracy: 0.9604\n", + "Epoch 222/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5548 - accuracy: 0.9610\n", + "Epoch 223/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5586 - accuracy: 0.9608\n", + "Epoch 224/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5504 - accuracy: 0.9612\n", + "Epoch 225/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5496 - accuracy: 0.9607\n", + "Epoch 226/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5763 - accuracy: 0.9562\n", + "Epoch 227/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5664 - accuracy: 0.9602\n", + "Epoch 228/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5404 - accuracy: 0.9648\n", + "Epoch 229/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5603 - accuracy: 0.9580\n", + "Epoch 230/350\n", + "391/391 [==============================] - 7s 
18ms/step - loss: 0.5574 - accuracy: 0.9610\n", + "Epoch 231/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5575 - accuracy: 0.9586\n", + "Epoch 232/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5660 - accuracy: 0.9585\n", + "Epoch 233/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5427 - accuracy: 0.9640\n", + "Epoch 234/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5468 - accuracy: 0.9611\n", + "Epoch 235/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5678 - accuracy: 0.9581\n", + "Epoch 236/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5472 - accuracy: 0.9622\n", + "Epoch 237/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5561 - accuracy: 0.9601\n", + "Epoch 238/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5471 - accuracy: 0.9621\n", + "Epoch 239/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5539 - accuracy: 0.9601\n", + "Epoch 240/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5492 - accuracy: 0.9619\n", + "Epoch 241/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5674 - accuracy: 0.9581\n", + "Epoch 242/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5451 - accuracy: 0.9618\n", + "Epoch 243/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5280 - accuracy: 0.9646\n", + "Epoch 244/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5628 - accuracy: 0.9579\n", + "Epoch 245/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5504 - accuracy: 0.9625\n", + "Epoch 246/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5284 - accuracy: 0.9647\n", + "Epoch 247/350\n", + "391/391 
[==============================] - 7s 18ms/step - loss: 0.5277 - accuracy: 0.9629\n", + "Epoch 248/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5490 - accuracy: 0.9599\n", + "Epoch 249/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5582 - accuracy: 0.9601\n", + "Epoch 250/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5701 - accuracy: 0.9587\n", + "Epoch 251/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5274 - accuracy: 0.9664\n", + "Epoch 252/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5343 - accuracy: 0.9618\n", + "Epoch 253/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5601 - accuracy: 0.9586\n", + "Epoch 254/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5544 - accuracy: 0.9608\n", + "Epoch 255/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5447 - accuracy: 0.9631\n", + "Epoch 256/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5355 - accuracy: 0.9634\n", + "Epoch 257/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9625\n", + "Epoch 258/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5554 - accuracy: 0.9593\n", + "Epoch 259/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5533 - accuracy: 0.9608\n", + "Epoch 260/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5511 - accuracy: 0.9618\n", + "Epoch 261/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5180 - accuracy: 0.9667\n", + "Epoch 262/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5495 - accuracy: 0.9582\n", + "Epoch 263/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5347 - accuracy: 0.9640\n", + "Epoch 264/350\n", + 
"391/391 [==============================] - 7s 18ms/step - loss: 0.5289 - accuracy: 0.9639\n", + "Epoch 265/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5340 - accuracy: 0.9623\n", + "Epoch 266/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5463 - accuracy: 0.9604\n", + "Epoch 267/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5383 - accuracy: 0.9639\n", + "Epoch 268/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5421 - accuracy: 0.9614\n", + "Epoch 269/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5213 - accuracy: 0.9651\n", + "Epoch 270/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5470 - accuracy: 0.9599\n", + "Epoch 271/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5388 - accuracy: 0.9634\n", + "Epoch 272/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9630\n", + "Epoch 273/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5326 - accuracy: 0.9638\n", + "Epoch 274/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5442 - accuracy: 0.9609\n", + "Epoch 275/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9634\n", + "Epoch 276/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5302 - accuracy: 0.9627\n", + "Epoch 277/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5403 - accuracy: 0.9617\n", + "Epoch 278/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5325 - accuracy: 0.9647\n", + "Epoch 279/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5370 - accuracy: 0.9619\n", + "Epoch 280/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5357 - accuracy: 0.9640\n", + "Epoch 
281/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5287 - accuracy: 0.9640\n", + "Epoch 282/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5315 - accuracy: 0.9613\n", + "Epoch 283/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5361 - accuracy: 0.9649\n", + "Epoch 284/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5382 - accuracy: 0.9614\n", + "Epoch 285/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5313 - accuracy: 0.9637\n", + "Epoch 286/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5414 - accuracy: 0.9618\n", + "Epoch 287/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5197 - accuracy: 0.9667\n", + "Epoch 288/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5287 - accuracy: 0.9613\n", + "Epoch 289/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5433 - accuracy: 0.9610\n", + "Epoch 290/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5371 - accuracy: 0.9637\n", + "Epoch 291/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5274 - accuracy: 0.9636\n", + "Epoch 292/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5302 - accuracy: 0.9638\n", + "Epoch 293/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5418 - accuracy: 0.9611\n", + "Epoch 294/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5264 - accuracy: 0.9648\n", + "Epoch 295/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5397 - accuracy: 0.9614\n", + "Epoch 296/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5217 - accuracy: 0.9652\n", + "Epoch 297/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5193 - accuracy: 0.9648\n", + 
"Epoch 298/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5296 - accuracy: 0.9643\n", + "Epoch 299/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5312 - accuracy: 0.9621\n", + "Epoch 300/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9632\n", + "Epoch 301/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5151 - accuracy: 0.9664\n", + "Epoch 302/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5239 - accuracy: 0.9634\n", + "Epoch 303/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5264 - accuracy: 0.9640\n", + "Epoch 304/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5168 - accuracy: 0.9652\n", + "Epoch 305/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5238 - accuracy: 0.9649\n", + "Epoch 306/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5178 - accuracy: 0.9635\n", + "Epoch 307/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5119 - accuracy: 0.9650\n", + "Epoch 308/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5237 - accuracy: 0.9634\n", + "Epoch 309/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5284 - accuracy: 0.9635\n", + "Epoch 310/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5121 - accuracy: 0.9660\n", + "Epoch 311/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5394 - accuracy: 0.9599\n", + "Epoch 312/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5048 - accuracy: 0.9697\n", + "Epoch 313/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5069 - accuracy: 0.9650\n", + "Epoch 314/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5089 - accuracy: 
0.9657\n", + "Epoch 315/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5242 - accuracy: 0.9627\n", + "Epoch 316/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5258 - accuracy: 0.9638\n", + "Epoch 317/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5221 - accuracy: 0.9643\n", + "Epoch 318/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5045 - accuracy: 0.9666\n", + "Epoch 319/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5195 - accuracy: 0.9652\n", + "Epoch 320/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5005 - accuracy: 0.9680\n", + "Epoch 321/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5238 - accuracy: 0.9615\n", + "Epoch 322/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9618\n", + "Epoch 323/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5160 - accuracy: 0.9674\n", + "Epoch 324/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5245 - accuracy: 0.9628\n", + "Epoch 325/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5109 - accuracy: 0.9669\n", + "Epoch 326/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5138 - accuracy: 0.9656\n", + "Epoch 327/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4976 - accuracy: 0.9667\n", + "Epoch 328/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5200 - accuracy: 0.9624\n", + "Epoch 329/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4939 - accuracy: 0.9700\n", + "Epoch 330/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4973 - accuracy: 0.9646\n", + "Epoch 331/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5258 - 
accuracy: 0.9619\n", + "Epoch 332/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9623\n", + "Epoch 333/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5265 - accuracy: 0.9655\n", + "Epoch 334/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5038 - accuracy: 0.9678\n", + "Epoch 335/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5162 - accuracy: 0.9643\n", + "Epoch 336/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5066 - accuracy: 0.9665\n", + "Epoch 337/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5063 - accuracy: 0.9660\n", + "Epoch 338/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5078 - accuracy: 0.9658\n", + "Epoch 339/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5310 - accuracy: 0.9632\n", + "Epoch 340/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4861 - accuracy: 0.9703\n", + "Epoch 341/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5143 - accuracy: 0.9631\n", + "Epoch 342/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5199 - accuracy: 0.9637\n", + "Epoch 343/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4992 - accuracy: 0.9685\n", + "Epoch 344/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5109 - accuracy: 0.9644\n", + "Epoch 345/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5066 - accuracy: 0.9657\n", + "Epoch 346/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5142 - accuracy: 0.9651\n", + "Epoch 347/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5092 - accuracy: 0.9649\n", + "Epoch 348/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 
0.5188 - accuracy: 0.9636\n", + "Epoch 349/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5069 - accuracy: 0.9677\n", + "Epoch 350/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4872 - accuracy: 0.9686\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, to_categorical(y_train, num_classes), epochs=EPOCHS, batch_size=BATCH_SIZE)" + ] + }, + { + "cell_type": "code", + "source": [ + "loss = history.history['loss']\n", + "epochs = range(1, len(loss) + 1)\n", + "\n", + "plt.plot(epochs, loss, 'b', label='Training Loss')\n", + "plt.title('Training Loss')\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.legend()\n", + "plt.show()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 472 + }, + "id": "6lrFuQrNRCyv", + "outputId": "3bc66200-18f3-483e-8c8c-3b44072fe7bb" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAHHCAYAAACRAnNyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABEw0lEQVR4nO3deXhU1f3H8c+EkCGBLCyBhB0BWQUREEEBKwhEpYD4UykqoIUKqNBKq6gsai2i1lq1xbVQkUqVAlIEBVRcEBUUEBARlE22yJaFJUByfn+czoSBTDaSnEnyfj3PPJm5986dc++M3g/fc+69HmOMEQAAQAgKc90AAACAYAgqAAAgZBFUAABAyCKoAACAkEVQAQAAIYugAgAAQhZBBQAAhCyCCgAACFkEFQAAELIIKgDybejQoWrYsGGh3jt58mR5PJ6ibRCAMo+gApQBHo8nX4/ly5e7bqoTQ4cOVZUqVVw3A0AheLjXD1D6vf766wGvX3vtNS1dulQzZ84MmH711VerVq1ahf6cU6dOKSsrS16vt8DvPX36tE6fPq1KlSoV+vMLa+jQoZozZ47S09NL/LMBnJ9w1w0AcP5uueWWgNeff/65li5des70sx07dkxRUVH5/pyKFSsWqn2SFB4ervBw/pcDoGDo+gHKiSuvvFKtW7fWV199pW7duikqKkoPPPCAJOntt9/Wtddeq9q1a8vr9apx48Z69NFHlZmZGbCOs8eobN++XR6PR0899ZReeuklNW7cWF6vVx07dtSqVasC3pvTGBWPx6O77rpL8+fPV+vWreX1etWqVSu9++6757R/+fLl6tChgypVqqTGjRvrxRdfLPJxL2+99Zbat2+vyMhI1ahRQ7fccot2794dsMy+ffs0bNgw1a1bV16vV4mJierXr5+2b9/uX2b16tXq3bu3atSoocjISDVq1Ei33357kbUTKE/45w1Qjhw8eFBJSUm6+eabdcstt/i7gWbMmKEqVarod7/7napUqaIPPvhAEydOVGpqqp588sk81/uvf/1LaWlp+s1vfiOPx6MnnnhC119/vX788cc8qzCffvqp5s6dq1GjRik6OlrPPvusBg4cqJ07d6p69eqSpDVr1qhPnz5KTEzUww8/rMzMTD3yyCOKj48//53yPzNmzNCwYcPUsWNHTZkyRfv379df//pXrVixQmvWrFFcXJwkaeDAgdq4caPuvvtuNWzYUMnJyVq6dKl27tzpf92rVy/Fx8fr/vvvV1xcnLZv3665c+cWWVuBcsUAKHNGjx5tzv7Pu3v37kaSeeGFF85Z/tixY+dM+81vfmOioqLMiRMn/NOGDBliGjRo4H+9bds2I8lUr17dHDp0yD/97bffNpLMf//7X/+0SZMmndMmSSYiIsJs3brVP23dunVGknnuuef80/r27WuioqLM7t27/dO2bNliwsPDz1lnToYMGWIqV64cdP7JkydNzZo1TevWrc3x48f90xcuXGgkmYkTJxpjjDl8+LCRZJ588smg65o3b56RZFatWpVnuwDkja4foBzxer0aNmzYOdMjIyP9z9PS0nTgwAF17dpVx44d03fffZfnem+66SZVrVrV/7pr166SpB9//DHP9/bs2VONGzf2v27Tpo1iYmL8783MzNSyZcvUv39/1a5d279ckyZNlJSUlOf682P16tVKTk7WqFGjAgb7XnvttWrevLneeecdSXY/RUREaPny5Tp8+HCO6/JVXhYuXKhTp04VSfuA8oygApQjderUUURExDnTN27cqAEDBig2NlYxMTGKj4/3D8RNSUnJc73169cPeO0LLcEO5rm91/d+33uTk5N1/PhxNWnS5JzlcppWGDt27JAkNWvW7Jx5zZs398/3er2aOnWqFi9erFq1aqlbt2564okntG/fPv/y3bt318CBA/Xwww+rRo0a6tevn6ZPn66MjIwiaStQ3hBUgHLkzMqJz5EjR9S9e3etW
7dOjzzyiP773/9q6dKlmjp1qiQpKysrz/VWqFAhx+kmH1c/OJ/3ujB27Fh9//33mjJliipVqqQJEyaoRYsWWrNmjSQ7QHjOnDlauXKl7rrrLu3evVu333672rdvz+nRQCEQVIBybvny5Tp48KBmzJihMWPG6LrrrlPPnj0DunJcqlmzpipVqqStW7eeMy+naYXRoEEDSdLmzZvPmbd582b/fJ/GjRvr3nvv1ZIlS7RhwwadPHlSf/7znwOWueyyy/TYY49p9erVmjVrljZu3KjZs2cXSXuB8oSgApRzvorGmRWMkydP6u9//7urJgWoUKGCevbsqfnz52vPnj3+6Vu3btXixYuL5DM6dOigmjVr6oUXXgjoolm8eLE2bdqka6+9VpK97syJEycC3tu4cWNFR0f733f48OFzqkEXX3yxJNH9AxQCpycD5VyXLl1UtWpVDRkyRPfcc488Ho9mzpwZUl0vkydP1pIlS3T55Zdr5MiRyszM1PPPP6/WrVtr7dq1+VrHqVOn9Mc//vGc6dWqVdOoUaM0depUDRs2TN27d9egQYP8pyc3bNhQv/3tbyVJ33//vXr06KEbb7xRLVu2VHh4uObNm6f9+/fr5ptvliT985//1N///ncNGDBAjRs3Vlpaml5++WXFxMTommuuKbJ9ApQXBBWgnKtevboWLlyoe++9Vw899JCqVq2qW265RT169FDv3r1dN0+S1L59ey1evFjjxo3ThAkTVK9ePT3yyCPatGlTvs5KkmyVaMKECedMb9y4sUaNGqWhQ4cqKipKjz/+uO677z5VrlxZAwYM0NSpU/1n8tSrV0+DBg3S+++/r5kzZyo8PFzNmzfXm2++qYEDB0qyg2m//PJLzZ49W/v371dsbKwuvfRSzZo1S40aNSqyfQKUF9zrB0Cp1b9/f23cuFFbtmxx3RQAxYQxKgBKhePHjwe83rJlixYtWqQrr7zSTYMAlAgqKgBKhcTERA0dOlQXXHCBduzYoWnTpikjI0Nr1qxR06ZNXTcPQDFhjAqAUqFPnz564403tG/fPnm9XnXu3Fl/+tOfCClAGUdFBQAAhCzGqAAAgJBFUAEAACHL6RiVyZMn6+GHHw6Y1qxZs3xfFyErK0t79uxRdHS0PB5PcTQRAAAUMWOM0tLSVLt2bYWF5V4zcT6YtlWrVlq2bJn/dXh4/pu0Z88e1atXrziaBQAAitmuXbtUt27dXJdxHlTCw8OVkJBQqPdGR0dLshsaExNTlM0CAADFJDU1VfXq1fMfx3PjPKhs2bJFtWvXVqVKldS5c2dNmTJF9evXz3HZjIyMgJt6paWlSZJiYmIIKgAAlDL5GbbhdDBtp06dNGPGDL377ruaNm2atm3bpq5du/oDyNmmTJmi2NhY/4NuHwAAyraQuo7KkSNH1KBBAz399NO64447zpl/dkXFVzpKSUmhogIAQCmRmpqq2NjYfB2/nXf9nCkuLk4XXnihtm7dmuN8r9crr9dbwq0CAACuhFRQSU9P1w8//KBbb73VdVMAALnIysrSyZMnXTcDIapixYqqUKFCkazLaVAZN26c+vbtqwYNGmjPnj2aNGmSKlSooEGDBrlsFgAgFydPntS2bduUlZXluikIYXFxcUpISDjv65w5DSo//fSTBg0apIMHDyo+Pl5XXHGFPv/8c8XHx7tsFgAgCGOM9u7dqwoVKqhevXp5XqwL5Y8xRseOHVNycrIke+fz8+E0qMyePdvlxwMACuj06dM6duyYateuraioKNfNQYiKjIyUJCUnJ6tmzZrn1Q1EFAYA5FtmZqYkKSIiwnFLEOp8QfbUqVPntR6CCgCgwLi/GvJSVL8RggoAAAhZBBUAAAqhYcOGeuaZZ/K9/PLly+XxeHTkyJFia1NZRFABAJRpHo8n18fkyZMLtd5Vq1ZpxIgR+V6+S5cu2rt3r2JjYwv1eflV1gJRSF3wLVQcPSodOCB5vVIhb+wMAAgRe/fu9T//97//rYkTJ
2rz5s3+aVWqVPE/N8YoMzNT4eF5Hx4LeimNiIgIJXBQKTAqKjlYsEBq2FC65RbXLQEAnK+EhAT/IzY2Vh6Px//6u+++U3R0tBYvXqz27dvL6/Xq008/1Q8//KB+/fqpVq1aqlKlijp27Khly5YFrPfsrh+Px6NXXnlFAwYMUFRUlJo2baoFCxb4559d6ZgxY4bi4uL03nvvqUWLFqpSpYr69OkTEKxOnz6te+65R3Fxcapevbruu+8+DRkyRP379y/0/jh8+LBuu+02Va1aVVFRUUpKStKWLVv883fs2KG+ffuqatWqqly5slq1aqVFixb53zt48GDFx8crMjJSTZs21fTp0wvdlvwgqOTAd/2i/52FBwAIwhhbhXbxKMpb6t5///16/PHHtWnTJrVp00bp6em65ppr9P7772vNmjXq06eP+vbtq507d+a6nocfflg33nijvvnmG11zzTUaPHiwDh06FHT5Y8eO6amnntLMmTP18ccfa+fOnRo3bpx//tSpUzVr1ixNnz5dK1asUGpqqubPn39e2zp06FCtXr1aCxYs0MqVK2WM0TXXXOM/jXj06NHKyMjQxx9/rPXr12vq1Kn+qtOECRP07bffavHixdq0aZOmTZumGjVqnFd78mRKsZSUFCPJpKSkFOl633rLGMmYbt2KdLUAUOodP37cfPvtt+b48ePGGGPS0+3/L1080tML3v7p06eb2NhY/+sPP/zQSDLz58/P872tWrUyzz33nP91gwYNzF/+8hf/a0nmoYce8r9OT083kszixYsDPuvw4cP+tkgyW7du9b/nb3/7m6lVq5b/da1atcyTTz7pf3369GlTv359069fv6DtPPtzzvT9998bSWbFihX+aQcOHDCRkZHmzTffNMYYc9FFF5nJkyfnuO6+ffuaYcOGBf3sM539WzlTQY7fVFRy4KuocBsLACgfOnToEPA6PT1d48aNU4sWLRQXF6cqVapo06ZNeVZU2rRp439euXJlxcTE+C8ln5OoqCg1btzY/zoxMdG/fEpKivbv369LL73UP79ChQpq3759gbbtTJs2bVJ4eLg6derkn1a9enU1a9ZMmzZtkiTdc889+uMf/6jLL79ckyZN0jfffONfduTIkZo9e7Yuvvhi/eEPf9Bnn31W6LbkF0ElB3T9AED+REVJ6eluHkV5Bf/KlSsHvB43bpzmzZunP/3pT/rkk0+0du1aXXTRRXneMbpixYoBrz0eT643b8xpeVOUfVqF8Otf/1o//vijbr31Vq1fv14dOnTQc889J0lKSkrSjh079Nvf/lZ79uxRjx49ArqqigNBJQe+WxJQUQGA3Hk8UuXKbh7FeXHcFStWaOjQoRowYIAuuugiJSQkaPv27cX3gTmIjY1VrVq1tGrVKv+0zMxMff3114VeZ4sWLXT69Gl98cUX/mkHDx7U5s2b1bJlS/+0evXq6c4779TcuXN177336uWXX/bPi4+P15AhQ/T666/rmWee0UsvvVTo9uQHpyfngK4fACjfmjZtqrlz56pv377yeDyaMGFCrpWR4nL33XdrypQpatKkiZo3b67nnntOhw8fztfl6devX6/o6Gj/a4/Ho7Zt26pfv34aPny4XnzxRUVHR+v+++9XnTp11K9fP0nS2LFjlZSUpAsvvFCHDx/Whx9+qBYtWkiSJk6cqPbt26tVq1bKyMjQwoUL/fOKC0ElB3T9AED59vTTT+v2229Xly5dVKNGDd13331KTU0t8Xbcd9992rdvn2677TZVqFBBI0aMUO/evfN1N+Ju3boFvK5QoYJOnz6t6dOna8yYMbruuut08uRJdevWTYsWLfJ3Q2VmZmr06NH66aefFBMToz59+ugvf/mLJHstmPHjx2v79u2KjIxU165dNXv27KLf8DN4jOvOsPOQmpqq2NhYpaSkKCYmpsjWu2SJ1Lu3dPHF0po1RbZaACj1Tpw4oW3btqlRo0aqVKmS6+aUO1lZWWrRooVuvPFGP
froo66bk6vcfisFOX5TUckBFRUAQCjYsWOHlixZou7duysjI0PPP/+8tm3bpl/96leum1ZiGEybAwbTAgBCQVhYmGbMmKGOHTvq8ssv1/r167Vs2bJiHxcSSqio5IDBtACAUFCvXj2tWLHCdTOcoqKSA7p+AAAIDQSVHND1AwC5K8XnYaCEFNVvhKCSA7p+ACBnvtNi87pCK3Ds2DFJ5159t6AYo5IDun4AIGfh4eGKiorSzz//rIoVKyosjH/vIpAxRseOHVNycrLi4uLydc2X3BBUckDXDwDkzOPxKDExUdu2bdOOHTtcNwchLC4uTgkJCee9HoJKDqioAEBwERERatq0Kd0/CKpixYrnXUnxIajkgIoKAOQuLCyMK9OiRNC5mAMG0wIAEBoIKjmg6wcAgNBAUMkBXT8AAIQGgkoO6PoBACA0EFRyQNcPAAChgaCSA7p+AAAIDQSVHFBRAQAgNBBUckBFBQCA0EBQyQGDaQEACA0ElRzQ9QMAQGggqOTgzNsTGOOuHQAAlHcElRyceddyun8AAHCHoJKDM4MK3T8AALhDUMnBmV0/VFQAAHCHoJIDKioAAIQGgkoOqKgAABAaCCo5YDAtAAChgaCSA7p+AAAIDQSVHND1AwBAaCCo5MDjyX5OUAEAwB2CShBcRh8AAPcIKkFwB2UAANwjqARBRQUAAPcIKkFQUQEAwD2CShC+igpBBQAAdwgqQdD1AwCAewSVIOj6AQDAPYJKEHT9AADgHkElCLp+AABwj6ASBF0/AAC4R1AJgooKAADuEVSCoKICAIB7BJUgGEwLAIB7BJUg6PoBAMA9gkoQdP0AAOAeQSUIun4AAHCPoBIEXT8AALhHUAmCrh8AANwjqARBRQUAAPcIKkFQUQEAwD2CShAMpgUAwL2QCSqPP/64PB6Pxo4d67opkuj6AQAgFIREUFm1apVefPFFtWnTxnVT/Oj6AQDAPedBJT09XYMHD9bLL7+sqlWrum6OH10/AAC45zyojB49Wtdee6169uzpuikB6PoBAMC9cJcfPnv2bH399ddatWpVvpbPyMhQRkaG/3VqampxNY2uHwAAQoCzisquXbs0ZswYzZo1S5UqVcrXe6ZMmaLY2Fj/o169esXWPioqAAC45yyofPXVV0pOTtYll1yi8PBwhYeH66OPPtKzzz6r8PBwZeaQEMaPH6+UlBT/Y9euXcXWPioqAAC456zrp0ePHlq/fn3AtGHDhql58+a67777VMGXFM7g9Xrl9XpLpH0MpgUAwD1nQSU6OlqtW7cOmFa5cmVVr179nOku0PUDAIB7zs/6CVV0/QAA4J7Ts37Otnz5ctdN8KPrBwAA96ioBEHXDwAA7hFUgqDrBwAA9wgqQVBRAQDAPYJKEFRUAABwj6ASBINpAQBwj6ASBF0/AAC4R1AJgq4fAADcI6gEQdcPAADuEVSCoOsHAAD3CCpB0PUDAIB7BJUgqKgAAOAeQSUIKioAALhHUAmCwbQAALhHUAmCrh8AANwjqARB1w8AAO4RVIKg6wcAAPcIKkHQ9QMAgHsElSDo+gEAwD2CShBUVAAAcI+gEgQVFQAA3COoBMFgWgAA3COoBEHXDwAA7hFUgqDrBwAA9wgqQdD1AwCAewSVIOj6AQDAPYJKEHT9AADgHkElCCoqAAC4R1AJgooKAADuEVSCYDAtAADuEVSCoOsHAAD3CCpB0PUDAIB7BJUg6PoBAMA9gkoQdP0AAOAeQSUIun4AAHCPoBIEFRUAANwjqARBRQUAAPcIKkEwmBYAAPcIKkHQ9QMAgHsElSDo+gEAwD2CShB0/QAA4B5BJQi6fgAAcI+gEgRdPwAAuEdQCYKKCgAA7hFUgqCiAgCAewSVIBhMCwCAewSVIOj6AQDAPYJKEHT9AADgHkElCLp+AABwj6ASBF0/AAC4R1AJgq4fAADcI6gEQUUFAAD3CCpBUFEBAMA9gkoQDKYFAMA9gkoQdP0AAOAeQSUIun4AA
HCPoBIEXT8AALhHUAmCrh8AANwjqARB1w8AAO4RVIKgogIAgHsElSCoqAAA4B5BJQgG0wIA4B5BJQi6fgAAcI+gEgRdPwAAuEdQCYKuHwAA3COoBEHXDwAA7hFUgqDrBwAA9wgqQYSdsWcIKwAAuOE0qEybNk1t2rRRTEyMYmJi1LlzZy1evNhlk/x8FRWJoAIAgCtOg0rdunX1+OOP66uvvtLq1at11VVXqV+/ftq4caPLZkmiogIAQCjwGGOM60acqVq1anryySd1xx135LlsamqqYmNjlZKSopiYmCJtR3q6FB1tnx89KkVFFenqAQAotwpy/A4voTblKTMzU2+99ZaOHj2qzp0757hMRkaGMjIy/K9TU1OLrT0VK2Y/P3Wq2D4GAADkwvlg2vXr16tKlSryer268847NW/ePLVs2TLHZadMmaLY2Fj/o169esXWLoIKAADuOe/6OXnypHbu3KmUlBTNmTNHr7zyij766KMcw0pOFZV69eoVS9ePZAfUZmVJe/dKCQlFvnoAAMqlgnT9OA8qZ+vZs6caN26sF198Mc9li3OMiiR5vdLJk9LOnVIxFm8AAChXCnL8dt71c7asrKyAqolLvu4fun4AAHDD6WDa8ePHKykpSfXr11daWpr+9a9/afny5XrvvfdcNsuPoAIAgFtOg0pycrJuu+027d27V7GxsWrTpo3ee+89XX311S6b5UdQAQDALadB5dVXX3X58XkiqAAA4FbIjVEJJeH/i3GnT7ttBwAA5RVBJRdUVAAAcIugkguCCgAAbhFUckFQAQDALYJKLggqAAC4RVDJBYNpAQBwi6CSCyoqAAC4RVDJBUEFAAC3ChVUdu3apZ9++sn/+ssvv9TYsWP10ksvFVnDQgFBBQAAtwoVVH71q1/pww8/lCTt27dPV199tb788ks9+OCDeuSRR4q0gS4RVAAAcKtQQWXDhg269NJLJUlvvvmmWrdurc8++0yzZs3SjBkzirJ9TjGYFgAAtwoVVE6dOiWv1ytJWrZsmX75y19Kkpo3b669e/cWXesco6ICAIBbhQoqrVq10gsvvKBPPvlES5cuVZ8+fSRJe/bsUfXq1Yu0gS4RVAAAcKtQQWXq1Kl68cUXdeWVV2rQoEFq27atJGnBggX+LqGygKACAIBb4YV505VXXqkDBw4oNTVVVatW9U8fMWKEoqKiiqxxrhFUAABwq1AVlePHjysjI8MfUnbs2KFnnnlGmzdvVs2aNYu0gS4RVAAAcKtQQaVfv3567bXXJElHjhxRp06d9Oc//1n9+/fXtGnTirSBLnHWDwAAbhUqqHz99dfq2rWrJGnOnDmqVauWduzYoddee03PPvtskTbQJSoqAAC4VaigcuzYMUVHR0uSlixZouuvv15hYWG67LLLtGPHjiJtoEsEFQAA3CpUUGnSpInmz5+vXbt26b333lOvXr0kScnJyYqJiSnSBrpEUAEAwK1CBZWJEydq3LhxatiwoS699FJ17txZkq2utGvXrkgb6BJBBQAAtwp1evINN9ygK664Qnv37vVfQ0WSevTooQEDBhRZ41xjMC0AAG4VKqhIUkJCghISEvx3Ua5bt26ZutibREUFAADXCtX1k5WVpUceeUSxsbFq0KCBGjRooLi4OD366KPKysoq6jY6Q1ABAMCtQlVUHnzwQb366qt6/PHHdfnll0uSPv30U02ePFknTpzQY489VqSNdIWgAgCAW4UKKv/85z/1yiuv+O+aLElt2rRRnTp1NGrUKIIKAAAoEoXq+jl06JCaN29+zvTmzZvr0KFD592oUMFgWgAA3CpUUGnbtq2ef/75c6Y///zzatOmzXk3KlRQUQEAwK1Cdf088cQTuvbaa7Vs2TL/NVRWrlypXbt2adGiRUXaQJcIKgAAuFWoikr37t31/fffa8CAATpy5IiOHDmi66+/Xhs3btTMmTOLuo3OEFQAAHDLY4wxRbWydevW6ZJLLlFmZmZRrTJXqampi
o2NVUpKSrFcun/+fGnAAKlzZ+mzz4p89QAAlEsFOX4XqqJSXlBRAQDALYJKLjjrBwAAtwgquaCiAgCAWwU66+f666/Pdf6RI0fOpy0hh6ACAIBbBQoqsbGxec6/7bbbzqtBoYSgAgCAWwUKKtOnTy+udoQkggoAAG4xRiUXDKYFAMAtgkouqKgAAOAWQSUXBBUAANwiqOSCoAIAgFsElVwQVAAAcIugkoszB9MW3R2RAABAfhFUcuGrqEhSCd1nEQAAnIGgkoszgwrdPwAAlDyCSi4IKgAAuEVQyQVBBQAAtwgquahQQfJ47HOCCgAAJY+gkgcuow8AgDsElTxwLRUAANwhqOSBoAIAgDsElTwQVAAAcIegkgeCCgAA7hBU8sBgWgAA3CGo5IGKCgAA7hBU8kBQAQDAHYJKHggqAAC4Q1DJA0EFAAB3CCp5YDAtAADuEFTy4KuonDzpth0AAJRHBJU8eL32b0aG23YAAFAeEVTyEBlp/5444bYdAACURwSVPPiCyvHjbtsBAEB5RFDJQ6VK9i9BBQCAkkdQyQNdPwAAuENQyQNdPwAAuOM0qEyZMkUdO3ZUdHS0atasqf79+2vz5s0um3QOX9cPFRUAAEqe06Dy0UcfafTo0fr888+1dOlSnTp1Sr169dLRo0ddNisAFRUAANwJd/nh7777bsDrGTNmqGbNmvrqq6/UrVs3R60KRFABAMCdkBqjkpKSIkmqVq2a45Zko+sHAAB3nFZUzpSVlaWxY8fq8ssvV+vWrXNcJiMjQxlnXCI2NTW12NtFRQUAAHdCpqIyevRobdiwQbNnzw66zJQpUxQbG+t/1KtXr9jbRVABAMCdkAgqd911lxYuXKgPP/xQdevWDbrc+PHjlZKS4n/s2rWr2NtG1w8AAO447foxxujuu+/WvHnztHz5cjVq1CjX5b1er7y+uwSWECoqAAC44zSojB49Wv/617/09ttvKzo6Wvv27ZMkxcbGKtKXEBwjqAAA4I7Trp9p06YpJSVFV155pRITE/2Pf//73y6bFYCuHwAA3HHe9RPqqKgAAOBOSAymDWUEFQAA3CGo5IGuHwAA3CGo5OHMikop6KkCAKBMIajkwRdUsrKk06fdtgUAgPKGoJIHX9ePxDgVAABKGkElD16v5PHY5wQVAABKFkElDx4PA2oBAHCFoJIPvqBCRQUAgJJFUMkHrqUCAIAbBJV88AUVun4AAChZBJV8oOsHAAA3CCr5QNcPAABuEFTyga4fAADcIKjkA10/AAC4QVDJB7p+AABwg6CSD3T9AADgBkElH+j6AQDADYJKPlBRAQDADYJKPjBGBQAANwgq+UDXDwAAbhBU8oGuHwAA3CCo5EPlyvZvWprbdgAAUN4QVPIhNtb+TUlx2w4AAMobgko+EFQAAHCDoJIPcXH2L0EFAICSRVDJByoqAAC4QVDJB19QOXLEaTMAACh3CCr54AsqaWlSVpbbtgAAUJ4QVPLBF1SM4RRlAABKEkElHypVkrxe+5xxKgAAlByCSj4xTgUAgJJHUMknzvwBAKDkEVTyiaACAEDJI6jkE0EFAICSR1DJJ65OCwBAySOo5BODaQEAKHkElXyi6wcAgJJHUMknggoAACWPoJJPBBUAAEoeQSWffINpGaMCAEDJIajkExUVAABKHkElnwgqAACUPIJKPnF6MgAAJY+gkk/x8fbvzz9LxrhtCwAA5QVBJZ9q1bJ/T52SDh922xYAAMoLgko+eb3ZZ/7s3++0KQAAlBsElQJISLB/9+1z2w4AAMoLgkoB+IIKFRUAAEoGQaUAfONUqKgAAFAyCCoFQNcPAAAli6BSAHT9AABQsggqBUDXDwAAJYugUgBUVAAAKFkElQJgjAoAACWLoFIAvq6f5GQpK8ttWwAAKA8IKgUQHy95PFJmpnTwoOvWAABQ9hFUCqBiRalGDft87163bQEAoDwgqBRQgwb2748/um0HAADlAUGlgJo2tX+3bHHbDgAAy
gOCSgERVAAAKDkElQIiqAAAUHIIKgVEUAEAoOQQVArIF1R275aOHnXbFgAAyjqCSgFVq2YfkrR1q9u2AABQ1hFUCoHuHwAASgZBpRAuvND+3bzZbTsAACjrCCqF0Lat/bt6tdt2AABQ1jkNKh9//LH69u2r2rVry+PxaP78+S6bk2+dOtm/X3whGeO2LQAAlGVOg8rRo0fVtm1b/e1vf3PZjAK75BKpQgV7v5/du123BgCAsivc5YcnJSUpKSnJZRMKJSpKuugiae1aW1WpW9d1iwAAKJtK1RiVjIwMpaamBjxcObP7BwAAFI9SFVSmTJmi2NhY/6NevXrO2nLppfbvp586awIAAGVeqQoq48ePV0pKiv+xa9cuZ23p1UvyeKSVK6Xt2501AwCAMq1UBRWv16uYmJiAhyt160q/+IV9/q9/OWsGAABlWqkKKqHmllvs35kzOU0ZAIDi4DSopKena+3atVq7dq0kadu2bVq7dq127tzpsln5NnCgVKmS9N130tdfu24NAABlj9Ogsnr1arVr107t2rWTJP3ud79Tu3btNHHiRJfNyreYGKl/f/t85kynTQEAoEzyGFN6Oy1SU1MVGxurlJQUZ+NV3nlHuu46qWZNe/G3cKdXpgEAIPQV5PjNGJXz1KuXFB8vJSdLS5e6bg0AAGULQeU8Vawo3XyzfU73DwAARYugUgRuvdX+nT9fSktz2hQAAMoUgkoR6NBBatZMOn5cmjvXdWsAACg7CCpFwOPJrqo88ojk8BZEAACUKQSVInLXXVLDhtKPP0qjR7tuDQAAZQNBpYjExkqzZklhYdLrr9sHAAA4PwSVItSlizRpkn0+apS0Y4fb9gAAUNoRVIrYAw9Il19uz/4ZPZp7AAEAcD4IKkUsPFx6+WUpIsJetfaZZwgrAAAUFkGlGLRoIU2YYJ//7ne2Gygry22bAAAojQgqxeTBB6UnnrCDa194QbrjDunIEdetAgCgdCGoFBOPR/r976XXXrNhZcYM6YILpA8+cN0yAABKD4JKMRs8WFq4UGrZUjp8WPrlL6Vnn5V++sl1ywAACH0ElRKQlCR9/bW90/LRo9KYMVKrVtKqVa5bBgBAaAt33YDywuuV5s2TnnvOXhhu/Xrp6qul22+3l9yvWVO6/np73yBjbNcRAADlnceY0nvybGpqqmJjY5WSkqKYmBjXzcm3tDSpTx/ps88Cp3s8dtDtvHlSx47SxIm2u+iFF6Rx46Ru3ezZQ5s3S82bS+npUmamFBfnZDMAACiUghy/CSqOnDolzZ8vvfuuVKeOtG6dtGBB8OXr1ZM2bpRuucUud/fd0ttv2zs2r1snJSZKX31lB+62a1dimxHS0tNtGLz6aipUABBKCCqlkDH2zssvvGCvu7Jli62sHD0qRUfb7qFGjaRt28597623SjfcIPXvb9fTtav00kvShRfaUJOcLA0YIMXHS3Pm2JB000021EjSn/4kbdggvfKKFBVl52dl2e6qwmzHrl02WLkOB7/+tfTqq9K0adKdd7ptCwAgG0GljMjIkE6ftpWTwYPttLAwqU0bae1a+9wY+6hY0QYMn0qVpFq1su83VKmSPV360Uft6/bt7fiYX/zCDuw1xnY11aljbwMQGWlPqV6+3FZyKlSQLr7YDgT+4AOpbVsbZF5+WfriC2nEiOwzmsaMsV1VTz6Z83ZlZUmrV9uKx2WXZYejihWD7wvfuJ3Tp+3r8DxGV506ZYNZSop0xRXSsmV2G/J6HwCg+BFUyhhjpOeft908fftKVavasSzXXmtPc54yxS7Xo4etigwfbg/MklSliq3ErF+f87orV7ZVm/yKipKOHbNtiIqSdu+208PDpdmzpbvukvbts9NuvVWqVk2KiZG+/dYOFG7RQnr8cenzz+0yDRva7Zgxww40vuEGW905ccLeMyk5WRo2zIalm2+2laKwMLvcCy/YWxV8/LFd/qqrsoPIBx/Y/eHj9dr9sGSJ9Pe/SwMH2vZIN
szExmYv++OP0kcfSdddZ8NOQXz3nR0YXa1awd4HAOVJgY7fphRLSUkxkkxKSorrpji1a5cxK1YYc/y4fZ2ZaczXXxuzdKkxycl2eps2tvZSpYoxa9caM3WqMZGRvnqMMbVq2b8xMXbeRRfZ1926GfPss8b8+c/GxMdnL+97NGtmzLXXBk4LCzt3ubMflSsbU7Vq4LSKFY2Jjs5+HRNjTIUKwdcxcqQxCxcGLn/ZZcaMH29MixY5v6dmTfs3IsKYv/zFmFtuMcbjMWbKFGNOnDBm6NDsZVu3NubDD4154AFjrr7amDp1jBk0yJh77zWmZUtjFiyw+/vkSWP27zfm3/+262rRwq5rwwZjLrzQmD/+0S53+rT9m5VlzOrVdt0AUB4V5PhNRaWc+O47W4UZPlwaOtRO+/e/bZWiaVNblZg7154inZBgqzdbtkgXXZQ91mT3bjsAuHdvafJkO27m1VftWUfDhtnTriU75iUiwp6xlJEhHTggNW4s/fe/trvnyitt91JamtSpk11Py5a26iJJtWvbLp7kZPu6Uyd7xtObb0r33GPHv9x0k40TFSrYM5+8XvtZZ+vS5dyzq4Jp0MB2lYWF2a6vvCpNHo/0f/9n1//TT/a177+mP/7Rjg/68svsbfjmG9uFt2qVHQAtSffeK9WoYSs4J0/a5UaNsoOj//pXu08ef9wus2aN9OGHdtD00qW2u+3RR+24pQkT7Nlg/frZgdbvvGOrXt262apaTjIz7W+gZUvbreezfr1td9269to/rscaFbUjR+z3kJRkK4AASh4VFeTb118bs3v3+a8nK8uY55+3VY709Py/b/t2Ww1KSzNm0iRblcjMNObUKWNWrTLmp5/sus/217/aqohkTMeO9jPXrTPmtdeMGTzYVjK6d7cVpVGjjPnvf7MrOKNGGTNtmq0iRUUZM2BAdhUlNtaYJUtsm7xeW+UZNMguv2CBraTUqWPMTTflXLFp0iTvapKvohNsntdrTO3a2a9btDBm2DBbrTl72fvuM6Z588BpjRtnP2/QwJjhw41p186Ytm3tY+JEW9351a+ylxsxwu7z994L/Jzhw4156y1bpXr4YVs5ysiwn9url13H735nzBtvGHP4sDFffmnMXXcZ89ln9ns6fdqYTZuMef99u8yPP577XR47Zsznn9vvdN263H8vhw/n/HvIzQ8/2N+Uz9NPZ2/fihUFWxeAolGQ4zdBBaXW4cM2gBw5kr/l58835vbbjTl0yL7OyDAmJcUexObONeadd7LnGWNDUnJy4DqysrK7cFautAfwZ54xZs8e29V24oQNSr4D4TPPGNOnjzE9exrzz38aM2SInXbwoDEvvmhDxYAB9iD9yivGXHFF9nujo41JSAgMIddcY8z999vHmdPr1DHm5psDA9eZYefsR5cu9u+ZXWsDBmR377Vtm3MwiooyplOnnNcZF2dMpUrZrzt3NuaCCwKXiYy02//QQ8aMHWsD5dndhnffbb/TEyeMefllY2691Zjf/taYp54yJjzcbuezzxqTlGTM99/bkDtihO2e++QTG5g3bbLf0VNP2fX26GHM0aN2Wteu2Z/XtKntuktPN+aXvzTmuuvsOs524oQxM2bY0BPMrl3GjBljfwdny8qyv7877jBm69bAeRkZwdfp+2ygrCGoAI4dPpz7QS2YrCw7tuWjj4z5+Wdjdu60B/UxY4x5993AZf/xDztGKDbWhqbUVPs6MtK+PyXFVjeGDbNB7L337AHdd5AODzfmP/+xjzPHFbVta6sc//mPMR062GAybJitXPmWqVDBjmV66ikbLM4MHK1bBwagqChb9Tk7lJz5iImxwebMkJaYmHdlqlYtYxo2PHd67do2cJxZuQoPt2O0zl524UJbOTpzWrdudvxSz57G3HCDMb172+lVq9rxWsOG2fbOmWO/izlzsit2lSrZitPAgbZiOH68DUS+dV9xRXZV6D//seO1br3VVrTONmGCDYxDhtjfw9l27
TJm48bs16tXZ1ezcvptLV0aWEHNrTq1cqWtcvK/VxQHggpQjvgqPMbYgHHgQPBljx+33UAREfYg6TNzpjG/+IUNRWdXkXyysmz3V//+ge/1tWH2bGP+9jf7fPdu+/yVV7K7AjMybCBo1cqYX//aVlRmzrTVJd8Bc+lSG7Z8B/WEBGMefDC7O6tt2+x5Zw7GTkzMrpSc3a122WXnDtzu0MGGP19VRbKB4P/+zwaavAKS7+HxZFenfKHszPlnfq7Xax+S7U7btCkwOF1xhR2ofcMNxkyfbsPmmQHy0kttOGzTxla/+vWzbQ0LsyH0gw9sQPR4jHnpJWMee8x2xW3YYMyiRcb84Q/Z+3T/fhtoqla1gergwezvMiUlsGI3dqyd/sILdlsfftiYvXsDv/tXXrEVwiVLbCXS93sBgiGoAAgqIyPwwBRqTp2yY6feecd26xhjw84XX9iD4sMPG/Pkk7bC8MQTtppx5Ig9MO7YYcynn2Z3W3XsaMy+fbbbZ+tWe3CX7EH1yy8DQ8X999vP2r7dVk2efNJ29/z613ac0Ny5dkzOVVfZM8Fuuy0wsNx/vz2At2xpg1Plytnzn3nGditOmpRdMfKdade6tR0LFSwM9eplq2a5Bab4eGOqVct/wOrZM7AL7MIL7T7csuXcSlb16nZfnRngqlSxQejYMTvm6ezwNnu2DVTdutluUd93OHas7bp7+20bfHbsMObvf7ftmTXLmG++sd/VZ58Zc889trssI8OG748+st//yZP2cbaUFPtdZ2UZs3x5dqUpI8Pu/1mzAsfPvf++HRuVmWk/98xxTOcrLS3n8VjFITMz+D8uQhlBBUC5tmSJrfqcffDJyrLdaVlZ9uHrjurYMeeDX258FaZXXzXm22+zp586ZefNmGEP2qNHZ887etSYK6/MPqi3aWMPrps32wP/yJG2glSnjp1fr54NPy++mP2e3/zGmOees6fXf/BB4Kn4l1ySvf7Wre3nh4cb06iRfX733edWfXyhLikp+7IEjRvb/edrx5ldYmd2AdaokV31ufJKO3hbCuz6q13bdkHmNNC8bt1zL2cQFRU4rX797MpU3brZ46Dq17cVqKFD7XPf8o0a2b+RkfYSAFddFRi6HnvMVs4k+znt29vn111nfwNbtxrz6KO2zYMG2e7NzEwbPtauDRzHdqYdO2xA27vX7qMKFc7thjtwwFa4UlPPfb+vKjl8uB17depU3r/JnTuzu0xHjsweh3W2776z3Yg5fa4rBBUAyIdFi2w3SnH96/fAgXO7QHxnTY0Ykfv4j5Mns4NWZqYx48bZ7puzw9cPP9hqzuuv2wPVyZO2MpKVZUOQb2yL7zpLCxdmV3CGD7fVkjMrOrVqZY9jmTAh8CC/Z4/9/Ndft6HBN+/WW+3yycmBlSTfdYt8j8REO9i8cWMbcnzT27e3XXBnVoWuvjq76nR2+Cnoo0qV7ACT26N58+zuuTMfHTpkf36VKvaMwshIG4JeeSU7+Pjm+5736ZP9Xf7619nTw8Jsde/zz+30lSttVenMClmNGjYoLlliA09Ghu3SGz7cVoZSU22QPbOdffoYs369rR75BmEfOZL9nYwfb0Px739vvzNfJSYrywaxFSts0PrTn2ygK86B3FxHBQAQ1HvvSa+/Lj31lL3Vxrvv2ntinTolTZpkr+cjSYcO2WsXNWhgr+9Tp072OjIy7K0wfvrJ3mfMd2+w8ePttX+6dLFXgp47V1q50t62Y+DA7Lu9r1tnb29RubK9dk98vD3crl5trzTdtau9ntPMmfaK0wMH2us9NW5sr/68dKl91K0rde5sr2SdliYtXGhv8fF//2fb1q6d9I9/SK1b2ytbv/OOVL++vYXI++/bx3XXSX/4g72WkWSv9dSli72e01NP2dt+SPaaROnpOe/TsDB7X7aUlMDpffrYa1L98IN9HRt77jJhYfYzKle215HasuXc9bduba/aLdnrL4WHS1On2qt7P/KIva7SiRPZ1
3Nq1Mi+Z+tWadMm+77Kle1f3zWi2rWTDh603+XJk/baV2e6+GK7v2rXznmbzweX0AcAOHHypDR9ur33V2Ji7svu22cPuDVqFH079u2zF7rs1i37Bqy52b3bXowxNtYGFd+FDufNs4+RI22AmzNH2r7drveDD6Q33rDb8MorNvTddpsNdMZkXwRTsheR/Pe/7W1QPv3UBpijRwMvVjlrlp0+ZYq93cjSpTboGRN4L7ewMNu+zEx7Ic3rrrMh7De/sfN9tzoJpk0be6HItLTA6dHRdlqjRvbvgQPSJZfYgOgLOUWFoAIAgEOHDtmbtsbF2apHu3a2EuSzbp30ySe2GjVmjJ3/0EPnricry96Etn9/W02Kj5feesvOGzzYVpx8VZRXX7WB6Je/tJWQtDR7r7VLL7U3iZ02zYaqdets2Bo71laeeve21aPu3W2FpVo1aedOG8wOHLCf/Z//5C/w5RdBBQCAMiQz0waFjAzbpXbBBfbO9/m9xcWWLdI119guot//Pn/vWbFCuvpqe0uQ3/62aG+nQVABAADnbe/evLvwCqMgx+8iLOQAAICypDhCSkERVAAAQMgiqAAAgJBFUAEAACGLoAIAAEIWQQUAAIQsggoAAAhZBBUAABCyCCoAACBkEVQAAEDIIqgAAICQRVABAAAhi6ACAABCFkEFAACErHDXDTgfxhhJ9nbRAACgdPAdt33H8dyU6qCSlpYmSapXr57jlgAAgIJKS0tTbGxsrst4TH7iTIjKysrSnj17FB0dLY/HUyTrTE1NVb169bRr1y7FxMQUyTpLG/YB+6C8b7/EPijv2y+xD6Ti2wfGGKWlpal27doKC8t9FEqprqiEhYWpbt26xbLumJiYcvvD9GEfsA/K+/ZL7IPyvv0S+0Aqnn2QVyXFh8G0AAAgZBFUAABAyCKonMXr9WrSpEnyer2um+IM+4B9UN63X2IflPftl9gHUmjsg1I9mBYAAJRtVFQAAEDIIqgAAICQRVABAAAhi6ACAABCFkHlLH/729/UsGFDVapUSZ06ddKXX37puknFYvLkyfJ4PAGP5s2b++efOHFCo0ePVvXq1VWlShUNHDhQ+/fvd9ji8/fxxx+rb9++ql27tjwej+bPnx8w3xijiRMnKjExUZGRkerZs6e2bNkSsMyhQ4c0ePBgxcTEKC4uTnfccYfS09NLcCvOT177YOjQoef8Lvr06ROwTGneB1OmTFHHjh0VHR2tmjVrqn///tq8eXPAMvn57e/cuVPXXnutoqKiVLNmTf3+97/X6dOnS3JTCiU/23/llVee8xu48847A5YprdsvSdOmTVObNm38FzDr3LmzFi9e7J9flr9/n7z2Qcj9Bgz8Zs+ebSIiIsw//vEPs3HjRjN8+HATFxdn9u/f77ppRW7SpEmmVatWZu/evf7Hzz//7J9/5513mnr16pn333/frF692lx22WWmS5cuDlt8/hYtWmQefPBBM3fuXCPJzJs3L2D+448/bmJjY838+fPNunXrzC9/+UvTqFEjc/z4cf8yffr0MW3btjWff/65+eSTT0yTJk3MoEGDSnhLCi+vfTBkyBDTp0+fgN/FoUOHApYpzfugd+/eZvr06WbDhg1m7dq15pprrjH169c36enp/mXy+u2fPn3atG7d2vTs2dOsWbPGLFq0yNSoUcOMHz/exSYVSH62v3v37mb48OEBv4GUlBT//NK8/cYYs2DBAvPOO++Y77//3mzevNk88MADpmLFimbDhg3GmLL9/fvktQ9C7TdAUDnDpZdeakaPHu1/nZmZaWrXrm2mTJnisFXFY9KkSaZt27Y5zjty5IipWLGieeutt/zTNm3aZCSZlStXllALi9fZB+msrCyTkJBgnnzySf+0I0eOGK/Xa9544w1jjDHffvutkWRWrVrlX2bx4sXG4/GY3bt3l1jbi0qwoNKvX7+g7ylr+yA5OdlIMh999JExJn+//UWLFpmwsDCzb98+/zLTpk0zMTExJ
iMjo2Q34Dydvf3G2IPUmDFjgr6nLG2/T9WqVc0rr7xS7r7/M/n2gTGh9xug6+d/Tp48qa+++ko9e/b0TwsLC1PPnj21cuVKhy0rPlu2bFHt2rV1wQUXaPDgwdq5c6ck6auvvtKpU6cC9kXz5s1Vv379Mrsvtm3bpn379gVsc2xsrDp16uTf5pUrVyouLk4dOnTwL9OzZ0+FhYXpiy++KPE2F5fly5erZs2aatasmUaOHKmDBw/655W1fZCSkiJJqlatmqT8/fZXrlypiy66SLVq1fIv07t3b6Wmpmrjxo0l2Przd/b2+8yaNUs1atRQ69atNX78eB07dsw/ryxtf2ZmpmbPnq2jR4+qc+fO5e77l87dBz6h9Bso1TclLEoHDhxQZmZmwI6XpFq1aum7775z1Kri06lTJ82YMUPNmjXT3r179fDDD6tr167asGGD9u3bp4iICMXFxQW8p1atWtq3b5+bBhcz33bl9P375u3bt081a9YMmB8eHq5q1aqVmf3Sp08fXX/99WrUqJF++OEHPfDAA0pKStLKlStVoUKFMrUPsrKyNHbsWF1++eVq3bq1JOXrt79v374cfye+eaVFTtsvSb/61a/UoEED1a5dW998843uu+8+bd68WXPnzpVUNrZ//fr16ty5s06cOKEqVapo3rx5atmypdauXVtuvv9g+0AKvd8AQaWcSkpK8j9v06aNOnXqpAYNGujNN99UZGSkw5bBpZtvvtn//KKLLlKbNm3UuHFjLV++XD169HDYsqI3evRobdiwQZ9++qnrpjgRbPtHjBjhf37RRRcpMTFRPXr00A8//KDGjRuXdDOLRbNmzbR27VqlpKRozpw5GjJkiD766CPXzSpRwfZBy5YtQ+43QNfP/9SoUUMVKlQ4Z3T3/v37lZCQ4KhVJScuLk4XXnihtm7dqoSEBJ08eVJHjhwJWKYs7wvfduX2/SckJCg5OTlg/unTp3Xo0KEyu18uuOAC1ahRQ1u3bpVUdvbBXXfdpYULF+rDDz9U3bp1/dPz89tPSEjI8Xfim1caBNv+nHTq1EmSAn4DpX37IyIi1KRJE7Vv315TpkxR27Zt9de//rXcfP9S8H2QE9e/AYLK/0RERKh9+/Z6//33/dOysrL0/vvvB/TblVXp6en64YcflJiYqPbt26tixYoB+2Lz5s3auXNnmd0XjRo1UkJCQsA2p6am6osvvvBvc+fOnXXkyBF99dVX/mU++OADZWVl+f9DLmt++uknHTx4UImJiZJK/z4wxuiuu+7SvHnz9MEHH6hRo0YB8/Pz2+/cubPWr18fENiWLl2qmJgYf+k8VOW1/TlZu3atJAX8Bkrr9geTlZWljIyMMv/958a3D3Li/DdQ5MNzS7HZs2cbr9drZsyYYb799lszYsQIExcXFzCyuay49957zfLly822bdvMihUrTM+ePU2NGjVMcnKyMcaeole/fn3zwQcfmNWrV5vOnTubzp07O271+UlLSzNr1qwxa9asMZLM008/bdasWWN27NhhjLGnJ8fFxZm3337bfPPNN6Zfv345np7crl0788UXX5hPP/3UNG3atNScmmtM7vsgLS3NjBs3zqxcudJs27bNLFu2zFxyySWmadOm5sSJE/51lOZ9MHLkSBMbG2uWL18ecOrlsWPH/Mvk9dv3nZrZq1cvs3btWvPuu++a+Pj4UnF6al7bv3XrVvPII4+Y1atXm23btpm3337bXHDBBaZbt27+dZTm7TfGmPvvv9989NFHZtu2beabb74x999/v/F4PGbJkiXGmLL9/fvktg9C8TdAUDnLc889Z+rXr28iIiLMpZdeaj7//HPXTSoWN910k0lMTDQRERGmTp065qabbjJbt271zz9+/LgZNWqUqVq1qomKijIDBgwwe/fuddji8/fhhx8aSec8hgwZYoyxpyhPmDDB1KpVy3i9XtOjRw+zefPmgHUcPHjQDBo0yFSpUsXExMSYYcOGmbS0NAdbUzi57
YNjx46ZXr16mfj4eFOxYkXToEEDM3z48HOCemneBzltuyQzffp0/zL5+e1v377dJCUlmcjISFOjRg1z7733mlOnTpXw1hRcXtu/c+dO061bN1OtWjXj9XpNkyZNzO9///uAa2gYU3q33xhjbr/9dtOgQQMTERFh4uPjTY8ePfwhxZiy/f375LYPQvE34DHGmKKv0wAAAJw/xqgAAICQRVABAAAhi6ACAABCFkEFAACELIIKAAAIWQQVAAAQsggqAAAgZBFUAJR6Ho9H8+fPd90MAMWAoALgvAwdOlQej+ecR58+fVw3DUAZEO66AQBKvz59+mj69OkB07xer6PWAChLqKgAOG9er1cJCQkBj6pVq0qy3TLTpk1TUlKSIiMjdcEFF2jOnDkB71+/fr2uuuoqRUZGqnr16hoxYoTS09MDlvnHP/6hVq1ayev1KjExUXfddVfA/AMHDmjAgAGKiopS06ZNtWDBAv+8w4cPa/DgwYqPj1dkZKSaNm16TrACEJoIKgCK3YQJEzRw4ECtW7dOgwcP1s0336xNmzZJko4eParevXuratWqWrVqld566y0tW7YsIIhMmzZNo0eP1ogRI7R+/XotWLBATZo0CfiMhx9+WDfeeKO++eYbXXPNNRo8eLAOHTrk//xvv/1Wixcv1qZNmzRt2jTVqFGj5HYAgMIrllsdAig3hgwZYipUqGAqV64c8HjssceMMfaOvXfeeWfAezp16mRGjhxpjDHmpZdeMlWrVjXp6en++e+8844JCwvz37m5du3a5sEHHwzaBknmoYce8r9OT083kszixYuNMcb07dvXDBs2rGg2GECJYowKgPP2i1/8QtOmTQuYVq1aNf/zzp07B8zr3Lmz1q5dK0natGmT2rZtq8qVK/vnX3755crKytLmzZvl8Xi0Z88e9ejRI9c2tGnTxv+8cuXKiomJUXJysiRp5MiRGjhwoL7++mv16tVL/fv3V5cuXQq1rQBKFkEFwHmrXLnyOV0xRSUyMjJfy1WsWDHgtcfjUVZWliQpKSlJO3bs0KJFi7R06VL16NFDo0eP1lNPPVXk7QVQtBijAqDYff755+e8btGihSSpRYsWWrdunY4ePeqfv2LFCoWFhalZs2aKjo5Ww4YN9f77759XG+Lj4zVkyBC9/vrreuaZZ/TSSy+d1/oAlAwqKgDOW0ZGhvbt2xcwLTw83D9g9a233lKHDh10xRVXaNasWfryyy/16quvSpIGDx6sSZMmaciQIZo8ebJ+/vln3X333br11ltVq1YtSdLkyZN15513qmbNmkpKSlJaWppWrFihu+++O1/tmzhxotq3b69WrVopIyNDCxcu9AclAKGNoALgvL377rtKTEwMmNasWTN99913kuwZObNnz9aoUaOUmJioN954Qy1btpQkRUVF6b333tOYMWPUsWNHRUVFaeDAgXr66af96xoyZIhOnDihv/zlLxo3bpxq1KihG264Id/ti4iI0Pjx47V9+3ZFRkaqa9eumj17dhFsOYDi5jHGGNeNAFB2eTwezZs3T/3793fdFAClEGNUAABAyCKoAACAkMUYFQDFit5lAOeDigoAAAhZBBUAABCyCCoAACBkEVQAAEDIIqgAAICQRVABAAAhi6ACAABCFkEFAACELIIKAAAIWf8P2HJWN2E5DGAAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3SRXs6V6bNEX", + "outputId": "4ed1f452-e232-41a3-caa9-ea01e89369d0" + }, + "outputs": [ + { + "output_type": "stream", + "name": 
"stdout", + "text": [ + "313/313 [==============================] - 1s 3ms/step - loss: 1.5913 - accuracy: 0.7413\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[1.5913232564926147, 0.7412999868392944]" + ] + }, + "metadata": {}, + "execution_count": 23 + } + ], + "source": [ + "model.evaluate(x_test, to_categorical(y_test, num_classes))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "---" + ], + "metadata": { + "id": "XyPoVUwrRRj5" + } + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "vXA26MZUbOXO" + }, + "outputs": [], + "source": [ + "def cnn(input_shape, num_classes):\n", + " \"\"\"CNN Model from (McMahan et. al., 2017).\n", + "\n", + " Communication-efficient learning of deep networks from decentralized data\n", + " \"\"\"\n", + " input_shape = tuple(input_shape)\n", + "\n", + " weight_decay = 0.004\n", + " model = keras.Sequential(\n", + " [\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " input_shape=input_shape,\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),\n", + " keras.layers.BatchNormalization(),\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " ),\n", + " keras.layers.BatchNormalization(),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(\n", + " 384, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(\n", + " 192, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(num_classes, activation=\"softmax\"),\n", + " ]\n", + " )\n", + " optimizer = SGD(learning_rate=0.1)\n", + " model.compile(\n", + " loss=\"categorical_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"]\n", + " )\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "source": [ 
+ "model_cnn = cnn(input_shape, num_classes)" + ], + "metadata": { + "id": "t098yVNYRxPu" + }, + "execution_count": 25, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "history_cnn = model_cnn.fit(x_train, to_categorical(y_train, num_classes), epochs=350, batch_size=100)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JinRA8quR2mr", + "outputId": "edc6a49c-3fa4-498d-fbb4-21cb439c9b38" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/350\n", + "500/500 [==============================] - 4s 7ms/step - loss: 4.1634 - accuracy: 0.4622\n", + "Epoch 2/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 2.3282 - accuracy: 0.6234\n", + "Epoch 3/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 1.5241 - accuracy: 0.6978\n", + "Epoch 4/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 1.1409 - accuracy: 0.7442\n", + "Epoch 5/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.9513 - accuracy: 0.7783\n", + "Epoch 6/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.8526 - accuracy: 0.8004\n", + "Epoch 7/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7955 - accuracy: 0.8228\n", + "Epoch 8/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7653 - accuracy: 0.8402\n", + "Epoch 9/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7479 - accuracy: 0.8540\n", + "Epoch 10/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7359 - accuracy: 0.8678\n", + "Epoch 11/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7267 - accuracy: 0.8774\n", + "Epoch 12/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7274 - accuracy: 0.8839\n", + "Epoch 13/350\n", + "500/500 [==============================] 
- 3s 7ms/step - loss: 0.7191 - accuracy: 0.8918\n", + "Epoch 14/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7182 - accuracy: 0.8971\n", + "Epoch 15/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7166 - accuracy: 0.9014\n", + "Epoch 16/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7239 - accuracy: 0.9033\n", + "Epoch 17/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7214 - accuracy: 0.9069\n", + "Epoch 18/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7103 - accuracy: 0.9122\n", + "Epoch 19/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7023 - accuracy: 0.9168\n", + "Epoch 20/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7128 - accuracy: 0.9147\n", + "Epoch 21/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7064 - accuracy: 0.9197\n", + "Epoch 22/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7090 - accuracy: 0.9177\n", + "Epoch 23/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7103 - accuracy: 0.9190\n", + "Epoch 24/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6981 - accuracy: 0.9232\n", + "Epoch 25/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7015 - accuracy: 0.9234\n", + "Epoch 26/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7026 - accuracy: 0.9253\n", + "Epoch 27/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6889 - accuracy: 0.9264\n", + "Epoch 28/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6924 - accuracy: 0.9275\n", + "Epoch 29/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6818 - accuracy: 0.9303\n", + "Epoch 30/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 
0.6961 - accuracy: 0.9273\n", + "Epoch 31/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6967 - accuracy: 0.9277\n", + "Epoch 32/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6932 - accuracy: 0.9318\n", + "Epoch 33/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6812 - accuracy: 0.9331\n", + "Epoch 34/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6779 - accuracy: 0.9321\n", + "Epoch 35/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6898 - accuracy: 0.9312\n", + "Epoch 36/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6800 - accuracy: 0.9328\n", + "Epoch 37/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6785 - accuracy: 0.9340\n", + "Epoch 38/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6713 - accuracy: 0.9370\n", + "Epoch 39/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6832 - accuracy: 0.9345\n", + "Epoch 40/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6778 - accuracy: 0.9349\n", + "Epoch 41/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6610 - accuracy: 0.9378\n", + "Epoch 42/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6612 - accuracy: 0.9385\n", + "Epoch 43/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6545 - accuracy: 0.9393\n", + "Epoch 44/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6609 - accuracy: 0.9369\n", + "Epoch 45/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6648 - accuracy: 0.9382\n", + "Epoch 46/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6587 - accuracy: 0.9385\n", + "Epoch 47/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6492 - accuracy: 
0.9420\n", + "Epoch 48/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6523 - accuracy: 0.9404\n", + "Epoch 49/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6648 - accuracy: 0.9378\n", + "Epoch 50/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6571 - accuracy: 0.9397\n", + "Epoch 51/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6493 - accuracy: 0.9413\n", + "Epoch 52/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6590 - accuracy: 0.9388\n", + "Epoch 53/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6542 - accuracy: 0.9412\n", + "Epoch 54/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6526 - accuracy: 0.9427\n", + "Epoch 55/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6311 - accuracy: 0.9462\n", + "Epoch 56/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6459 - accuracy: 0.9412\n", + "Epoch 57/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6436 - accuracy: 0.9438\n", + "Epoch 58/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6429 - accuracy: 0.9440\n", + "Epoch 59/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6459 - accuracy: 0.9421\n", + "Epoch 60/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6418 - accuracy: 0.9432\n", + "Epoch 61/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6357 - accuracy: 0.9444\n", + "Epoch 62/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6316 - accuracy: 0.9452\n", + "Epoch 63/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6348 - accuracy: 0.9451\n", + "Epoch 64/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6293 - accuracy: 0.9447\n", + "Epoch 
65/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6339 - accuracy: 0.9453\n", + "Epoch 66/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6223 - accuracy: 0.9482\n", + "Epoch 67/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6169 - accuracy: 0.9483\n", + "Epoch 68/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6217 - accuracy: 0.9456\n", + "Epoch 69/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6262 - accuracy: 0.9456\n", + "Epoch 70/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6168 - accuracy: 0.9488\n", + "Epoch 71/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6166 - accuracy: 0.9465\n", + "Epoch 72/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6248 - accuracy: 0.9458\n", + "Epoch 73/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6089 - accuracy: 0.9510\n", + "Epoch 74/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6155 - accuracy: 0.9472\n", + "Epoch 75/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6207 - accuracy: 0.9480\n", + "Epoch 76/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6123 - accuracy: 0.9502\n", + "Epoch 77/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6173 - accuracy: 0.9474\n", + "Epoch 78/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6020 - accuracy: 0.9510\n", + "Epoch 79/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5970 - accuracy: 0.9512\n", + "Epoch 80/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6211 - accuracy: 0.9454\n", + "Epoch 81/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5945 - accuracy: 0.9522\n", + "Epoch 82/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.6178 - accuracy: 0.9460\n", + "Epoch 83/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6082 - accuracy: 0.9504\n", + "Epoch 84/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5934 - accuracy: 0.9522\n", + "Epoch 85/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5979 - accuracy: 0.9512\n", + "Epoch 86/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5985 - accuracy: 0.9506\n", + "Epoch 87/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5924 - accuracy: 0.9520\n", + "Epoch 88/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5885 - accuracy: 0.9514\n", + "Epoch 89/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5934 - accuracy: 0.9515\n", + "Epoch 90/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6033 - accuracy: 0.9507\n", + "Epoch 91/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5958 - accuracy: 0.9523\n", + "Epoch 92/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5970 - accuracy: 0.9505\n", + "Epoch 93/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5900 - accuracy: 0.9536\n", + "Epoch 94/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5916 - accuracy: 0.9512\n", + "Epoch 95/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5955 - accuracy: 0.9519\n", + "Epoch 96/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5943 - accuracy: 0.9520\n", + "Epoch 97/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5852 - accuracy: 0.9523\n", + "Epoch 98/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5784 - accuracy: 0.9533\n", + "Epoch 99/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.5800 - accuracy: 0.9535\n", + "Epoch 100/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5691 - accuracy: 0.9552\n", + "Epoch 101/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5720 - accuracy: 0.9531\n", + "Epoch 102/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5766 - accuracy: 0.9541\n", + "Epoch 103/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5696 - accuracy: 0.9543\n", + "Epoch 104/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5753 - accuracy: 0.9538\n", + "Epoch 105/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5765 - accuracy: 0.9540\n", + "Epoch 106/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5590 - accuracy: 0.9576\n", + "Epoch 107/350\n", + "500/500 [==============================] - 4s 7ms/step - loss: 0.5675 - accuracy: 0.9537\n", + "Epoch 108/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5797 - accuracy: 0.9523\n", + "Epoch 109/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5745 - accuracy: 0.9549\n", + "Epoch 110/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5634 - accuracy: 0.9565\n", + "Epoch 111/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5626 - accuracy: 0.9556\n", + "Epoch 112/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5731 - accuracy: 0.9542\n", + "Epoch 113/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5737 - accuracy: 0.9539\n", + "Epoch 114/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5690 - accuracy: 0.9557\n", + "Epoch 115/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5670 - accuracy: 0.9558\n", + "Epoch 116/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.5583 - accuracy: 0.9550\n", + "Epoch 117/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5479 - accuracy: 0.9570\n", + "Epoch 118/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5639 - accuracy: 0.9541\n", + "Epoch 119/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5530 - accuracy: 0.9580\n", + "Epoch 120/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5579 - accuracy: 0.9562\n", + "Epoch 121/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5522 - accuracy: 0.9573\n", + "Epoch 122/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5641 - accuracy: 0.9542\n", + "Epoch 123/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5519 - accuracy: 0.9582\n", + "Epoch 124/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5387 - accuracy: 0.9588\n", + "Epoch 125/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5470 - accuracy: 0.9570\n", + "Epoch 126/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5583 - accuracy: 0.9545\n", + "Epoch 127/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5439 - accuracy: 0.9590\n", + "Epoch 128/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5541 - accuracy: 0.9557\n", + "Epoch 129/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5357 - accuracy: 0.9598\n", + "Epoch 130/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5512 - accuracy: 0.9564\n", + "Epoch 131/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5377 - accuracy: 0.9593\n", + "Epoch 132/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5414 - accuracy: 0.9568\n", + "Epoch 133/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.5476 - accuracy: 0.9556\n", + "Epoch 134/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5348 - accuracy: 0.9583\n", + "Epoch 135/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5451 - accuracy: 0.9572\n", + "Epoch 136/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5413 - accuracy: 0.9579\n", + "Epoch 137/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5257 - accuracy: 0.9604\n", + "Epoch 138/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5314 - accuracy: 0.9585\n", + "Epoch 139/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5326 - accuracy: 0.9591\n", + "Epoch 140/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5399 - accuracy: 0.9575\n", + "Epoch 141/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5402 - accuracy: 0.9588\n", + "Epoch 142/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5342 - accuracy: 0.9576\n", + "Epoch 143/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5380 - accuracy: 0.9577\n", + "Epoch 144/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5313 - accuracy: 0.9587\n", + "Epoch 145/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5303 - accuracy: 0.9589\n", + "Epoch 146/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5219 - accuracy: 0.9595\n", + "Epoch 147/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5418 - accuracy: 0.9567\n", + "Epoch 148/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5300 - accuracy: 0.9600\n", + "Epoch 149/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5088 - accuracy: 0.9607\n", + "Epoch 150/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.5319 - accuracy: 0.9577\n", + "Epoch 151/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5168 - accuracy: 0.9621\n", + "Epoch 152/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5161 - accuracy: 0.9606\n", + "Epoch 153/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5109 - accuracy: 0.9618\n", + "Epoch 154/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5190 - accuracy: 0.9593\n", + "Epoch 155/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5273 - accuracy: 0.9586\n", + "Epoch 156/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5093 - accuracy: 0.9630\n", + "Epoch 157/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5210 - accuracy: 0.9589\n", + "Epoch 158/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5046 - accuracy: 0.9636\n", + "Epoch 159/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5142 - accuracy: 0.9598\n", + "Epoch 160/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5294 - accuracy: 0.9588\n", + "Epoch 161/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5079 - accuracy: 0.9622\n", + "Epoch 162/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4972 - accuracy: 0.9635\n", + "Epoch 163/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5120 - accuracy: 0.9601\n", + "Epoch 164/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5091 - accuracy: 0.9624\n", + "Epoch 165/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5115 - accuracy: 0.9612\n", + "Epoch 166/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5071 - accuracy: 0.9614\n", + "Epoch 167/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.5067 - accuracy: 0.9628\n", + "Epoch 168/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5002 - accuracy: 0.9623\n", + "Epoch 169/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5094 - accuracy: 0.9602\n", + "Epoch 170/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5032 - accuracy: 0.9618\n", + "Epoch 171/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5024 - accuracy: 0.9618\n", + "Epoch 172/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4839 - accuracy: 0.9649\n", + "Epoch 173/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4931 - accuracy: 0.9620\n", + "Epoch 174/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5061 - accuracy: 0.9614\n", + "Epoch 175/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5023 - accuracy: 0.9620\n", + "Epoch 176/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5021 - accuracy: 0.9625\n", + "Epoch 177/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9651\n", + "Epoch 178/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4813 - accuracy: 0.9626\n", + "Epoch 179/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4906 - accuracy: 0.9630\n", + "Epoch 180/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4973 - accuracy: 0.9611\n", + "Epoch 181/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4970 - accuracy: 0.9629\n", + "Epoch 182/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4841 - accuracy: 0.9644\n", + "Epoch 183/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4872 - accuracy: 0.9631\n", + "Epoch 184/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4845 - accuracy: 0.9647\n", + "Epoch 185/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4756 - accuracy: 0.9648\n", + "Epoch 186/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9626\n", + "Epoch 187/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4819 - accuracy: 0.9633\n", + "Epoch 188/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5000 - accuracy: 0.9617\n", + "Epoch 189/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4783 - accuracy: 0.9652\n", + "Epoch 190/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4778 - accuracy: 0.9641\n", + "Epoch 191/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4815 - accuracy: 0.9623\n", + "Epoch 192/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4892 - accuracy: 0.9640\n", + "Epoch 193/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4850 - accuracy: 0.9637\n", + "Epoch 194/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4871 - accuracy: 0.9641\n", + "Epoch 195/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4739 - accuracy: 0.9651\n", + "Epoch 196/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4779 - accuracy: 0.9636\n", + "Epoch 197/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4658 - accuracy: 0.9663\n", + "Epoch 198/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9623\n", + "Epoch 199/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4826 - accuracy: 0.9635\n", + "Epoch 200/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4732 - accuracy: 0.9656\n", + "Epoch 201/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4790 - accuracy: 0.9648\n", + "Epoch 202/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4675 - accuracy: 0.9658\n", + "Epoch 203/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4743 - accuracy: 0.9633\n", + "Epoch 204/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4667 - accuracy: 0.9653\n", + "Epoch 205/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4760 - accuracy: 0.9624\n", + "Epoch 206/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4736 - accuracy: 0.9651\n", + "Epoch 207/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4744 - accuracy: 0.9636\n", + "Epoch 208/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4623 - accuracy: 0.9664\n", + "Epoch 209/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4591 - accuracy: 0.9670\n", + "Epoch 210/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4700 - accuracy: 0.9645\n", + "Epoch 211/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4690 - accuracy: 0.9653\n", + "Epoch 212/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4686 - accuracy: 0.9649\n", + "Epoch 213/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4546 - accuracy: 0.9667\n", + "Epoch 214/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4710 - accuracy: 0.9645\n", + "Epoch 215/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4677 - accuracy: 0.9653\n", + "Epoch 216/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4796 - accuracy: 0.9629\n", + "Epoch 217/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4601 - accuracy: 0.9673\n", + "Epoch 218/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4571 - accuracy: 0.9667\n", + "Epoch 219/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4652 - accuracy: 0.9648\n", + "Epoch 220/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4613 - accuracy: 0.9658\n", + "Epoch 221/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4510 - accuracy: 0.9679\n", + "Epoch 222/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4689 - accuracy: 0.9653\n", + "Epoch 223/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4490 - accuracy: 0.9677\n", + "Epoch 224/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4579 - accuracy: 0.9645\n", + "Epoch 225/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4465 - accuracy: 0.9682\n", + "Epoch 226/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4486 - accuracy: 0.9673\n", + "Epoch 227/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4628 - accuracy: 0.9638\n", + "Epoch 228/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4438 - accuracy: 0.9689\n", + "Epoch 229/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4528 - accuracy: 0.9650\n", + "Epoch 230/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4560 - accuracy: 0.9656\n", + "Epoch 231/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4532 - accuracy: 0.9670\n", + "Epoch 232/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4497 - accuracy: 0.9671\n", + "Epoch 233/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4474 - accuracy: 0.9675\n", + "Epoch 234/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4517 - accuracy: 0.9672\n", + "Epoch 235/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4531 - accuracy: 0.9660\n", + "Epoch 236/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4524 - accuracy: 0.9662\n", + "Epoch 237/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4423 - accuracy: 0.9669\n", + "Epoch 238/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4500 - accuracy: 0.9658\n", + "Epoch 239/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4564 - accuracy: 0.9653\n", + "Epoch 240/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4247 - accuracy: 0.9709\n", + "Epoch 241/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4395 - accuracy: 0.9670\n", + "Epoch 242/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4506 - accuracy: 0.9656\n", + "Epoch 243/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4332 - accuracy: 0.9697\n", + "Epoch 244/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4385 - accuracy: 0.9674\n", + "Epoch 245/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4414 - accuracy: 0.9672\n", + "Epoch 246/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4494 - accuracy: 0.9664\n", + "Epoch 247/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4429 - accuracy: 0.9677\n", + "Epoch 248/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4385 - accuracy: 0.9683\n", + "Epoch 249/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4283 - accuracy: 0.9697\n", + "Epoch 250/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4362 - accuracy: 0.9677\n", + "Epoch 251/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4360 - accuracy: 0.9678\n", + "Epoch 252/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4337 - accuracy: 0.9684\n", + "Epoch 253/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4485 - accuracy: 0.9664\n", + "Epoch 254/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4364 - accuracy: 0.9686\n", + "Epoch 255/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4394 - accuracy: 0.9681\n", + "Epoch 256/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4211 - accuracy: 0.9692\n", + "Epoch 257/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4226 - accuracy: 0.9694\n", + "Epoch 258/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4358 - accuracy: 0.9669\n", + "Epoch 259/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4260 - accuracy: 0.9696\n", + "Epoch 260/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4276 - accuracy: 0.9690\n", + "Epoch 261/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4286 - accuracy: 0.9683\n", + "Epoch 262/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4297 - accuracy: 0.9690\n", + "Epoch 263/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4210 - accuracy: 0.9696\n", + "Epoch 264/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4301 - accuracy: 0.9681\n", + "Epoch 265/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4283 - accuracy: 0.9687\n", + "Epoch 266/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4182 - accuracy: 0.9713\n", + "Epoch 267/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4202 - accuracy: 0.9681\n", + "Epoch 268/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4292 - accuracy: 0.9686\n", + "Epoch 269/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4213 - accuracy: 0.9699\n", + "Epoch 270/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4239 - accuracy: 0.9688\n", + "Epoch 271/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4213 - accuracy: 0.9686\n", + "Epoch 272/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4120 - accuracy: 0.9706\n", + "Epoch 273/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4196 - accuracy: 0.9697\n", + "Epoch 274/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4181 - accuracy: 0.9694\n", + "Epoch 275/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4176 - accuracy: 0.9692\n", + "Epoch 276/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4199 - accuracy: 0.9693\n", + "Epoch 277/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4298 - accuracy: 0.9686\n", + "Epoch 278/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4201 - accuracy: 0.9696\n", + "Epoch 279/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4046 - accuracy: 0.9715\n", + "Epoch 280/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4024 - accuracy: 0.9707\n", + "Epoch 281/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4200 - accuracy: 0.9685\n", + "Epoch 282/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4064 - accuracy: 0.9710\n", + "Epoch 283/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3937 - accuracy: 0.9725\n", + "Epoch 284/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4107 - accuracy: 0.9690\n", + "Epoch 285/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4096 - accuracy: 0.9709\n", + "Epoch 286/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.4103 - accuracy: 0.9696\n", + "Epoch 287/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4256 - accuracy: 0.9673\n", + "Epoch 288/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4131 - accuracy: 0.9715\n", + "Epoch 289/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3974 - accuracy: 0.9726\n", + "Epoch 290/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4112 - accuracy: 0.9703\n", + "Epoch 291/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3966 - accuracy: 0.9720\n", + "Epoch 292/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4136 - accuracy: 0.9687\n", + "Epoch 293/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4141 - accuracy: 0.9709\n", + "Epoch 294/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4057 - accuracy: 0.9714\n", + "Epoch 295/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3946 - accuracy: 0.9728\n", + "Epoch 296/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4176 - accuracy: 0.9682\n", + "Epoch 297/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4128 - accuracy: 0.9701\n", + "Epoch 298/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4058 - accuracy: 0.9712\n", + "Epoch 299/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3820 - accuracy: 0.9740\n", + "Epoch 300/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4019 - accuracy: 0.9694\n", + "Epoch 301/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4020 - accuracy: 0.9713\n", + "Epoch 302/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4151 - accuracy: 0.9681\n", + "Epoch 303/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.3961 - accuracy: 0.9724\n", + "Epoch 304/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3941 - accuracy: 0.9709\n", + "Epoch 305/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3992 - accuracy: 0.9710\n", + "Epoch 306/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3982 - accuracy: 0.9722\n", + "Epoch 307/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3842 - accuracy: 0.9727\n", + "Epoch 308/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4010 - accuracy: 0.9696\n", + "Epoch 309/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4085 - accuracy: 0.9690\n", + "Epoch 310/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3894 - accuracy: 0.9731\n", + "Epoch 311/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4014 - accuracy: 0.9696\n", + "Epoch 312/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3918 - accuracy: 0.9729\n", + "Epoch 313/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3906 - accuracy: 0.9708\n", + "Epoch 314/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3815 - accuracy: 0.9746\n", + "Epoch 315/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3944 - accuracy: 0.9701\n", + "Epoch 316/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4006 - accuracy: 0.9704\n", + "Epoch 317/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3836 - accuracy: 0.9748\n", + "Epoch 318/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3836 - accuracy: 0.9722\n", + "Epoch 319/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3873 - accuracy: 0.9715\n", + "Epoch 320/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.3858 - accuracy: 0.9728\n", + "Epoch 321/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3900 - accuracy: 0.9710\n", + "Epoch 322/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3927 - accuracy: 0.9719\n", + "Epoch 323/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3863 - accuracy: 0.9711\n", + "Epoch 324/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3857 - accuracy: 0.9726\n", + "Epoch 325/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3778 - accuracy: 0.9728\n", + "Epoch 326/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3951 - accuracy: 0.9698\n", + "Epoch 327/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3871 - accuracy: 0.9726\n", + "Epoch 328/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3910 - accuracy: 0.9707\n", + "Epoch 329/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3787 - accuracy: 0.9735\n", + "Epoch 330/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9707\n", + "Epoch 331/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3901 - accuracy: 0.9715\n", + "Epoch 332/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3710 - accuracy: 0.9741\n", + "Epoch 333/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9715\n", + "Epoch 334/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9722\n", + "Epoch 335/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3768 - accuracy: 0.9730\n", + "Epoch 336/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3739 - accuracy: 0.9738\n", + "Epoch 337/350\n", + "500/500 
[==============================] - 3s 7ms/step - loss: 0.3883 - accuracy: 0.9711\n", + "Epoch 338/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3715 - accuracy: 0.9732\n", + "Epoch 339/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3740 - accuracy: 0.9730\n", + "Epoch 340/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3902 - accuracy: 0.9715\n", + "Epoch 341/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3779 - accuracy: 0.9727\n", + "Epoch 342/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3883 - accuracy: 0.9708\n", + "Epoch 343/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3739 - accuracy: 0.9741\n", + "Epoch 344/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3823 - accuracy: 0.9714\n", + "Epoch 345/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3729 - accuracy: 0.9736\n", + "Epoch 346/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3730 - accuracy: 0.9731\n", + "Epoch 347/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3716 - accuracy: 0.9722\n", + "Epoch 348/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3841 - accuracy: 0.9722\n", + "Epoch 349/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3659 - accuracy: 0.9750\n", + "Epoch 350/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3740 - accuracy: 0.9721\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model_cnn.evaluate(x_test, to_categorical(y_test, num_classes))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eIzsv0QLR_tt", + "outputId": "0a7eb8e7-4d7f-40ef-e1b0-b3979cc65854" + }, + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + 
"313/313 [==============================] - 1s 2ms/step - loss: 1.4919 - accuracy: 0.7581\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[1.491928219795227, 0.7580999732017517]" + ] + }, + "metadata": {}, + "execution_count": 27 + } + ] + }, + { + "cell_type": "code", + "source": [ + "loss = history_cnn.history['loss']\n", + "epochs = range(1, len(loss) + 1)\n", + "\n", + "plt.plot(epochs, loss, 'b', label='Training Loss')\n", + "plt.title('Training Loss')\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.legend()\n", + "plt.show()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 472 + }, + "id": "rOXE49XhSGBy", + "outputId": "7bf4879e-632d-4762-977f-522402eee644" + }, + "execution_count": 28, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOn0lEQVR4nO3dd3hUZeL28XsSkiGBJISSAoQisPQmNaCAgkJgWUB2VcQlYFsQXLDsKlZA3eBiF0VZlawFWeEFVASRIiC9CApIEUWCkgQRkhBKCMl5/3h+GRhSCTM5Kd/Pdc2VzKnPORmdm6cdh2VZlgAAAMoJH7sLAAAA4EmEGwAAUK4QbgAAQLlCuAEAAOUK4QYAAJQrhBsAAFCuEG4AAEC5QrgBAADlCuEGAACUK4QbAF43cuRINWjQoFj7Tpo0SQ6Hw7MFAlCuEW6ACszhcBTptWrVKruLaouRI0eqatWqdhcDwGVy8GwpoOL64IMP3N6/9957WrZsmd5//3235TfccIPCw8OLfZ7MzExlZ2fL6XRe9r7nz5/X+fPnVbly5WKfv7hGjhypefPmKT09vcTPDaD4KtldAAD2uf32293eb9y4UcuWLcu1/FKnT59WYGBgkc/j5+dXrPJJUqVKlVSpEv+rAlB0NEsBKFCvXr3UqlUrbdu2TT169FBgYKAeffRRSdInn3yiAQMGqHbt2nI6nWrUqJGefvppZWVluR3j0j43P//8sxwOh55//nnNnDlTjRo1ktPpVKdOnbRlyxa3ffPqc+NwODRu3DgtXLhQrVq1ktPpVMuWLfXFF1/kKv+qVavUsWNHVa5cWY0aNdJbb73l8X48c+fOVYcOHRQQEKCaNWvq9ttv16+//uq2TVJSkkaNGqW6devK6XQqMjJSgwYN0s8//+zaZuvWrerbt69q1qypgIAANWzYUHfccYfHyglUFPxzCEChfv/9d8XExOjWW2/V7bff7mqiio+PV9WqVfXAAw+oatWqWrlypZ588kmlpaVp2rRphR539uzZOnnypP72t7/J4XDo3//+t2666Sb99NNPhdb2rF27VvPnz9e9996roKAgvfrqqxo6dKgSEhJUo0YNSdL27dvVr18/RUZGavLkycrKytKUKVNUq1atK78p/yc+Pl6jRo1Sp06dFBcXp+TkZL3yyitat26dtm/frmrVqkmShg4dqt27d+u+++5TgwYNdPToUS1btkwJCQmu9zfeeKNq1aqlRx55RNWqVdPPP/+s+fPne6ysQIVhAcD/GTt2rHXp/xZ69uxpSbLefPPNXNufPn0617K//e1vVmBgoHX27FnXstjYWKt+/fqu9wcPHrQkWTVq1LCOHz/uWv7JJ59YkqzPPvvMteypp57KVSZJlr+/v3XgwAHXsm+//daSZL322muuZQMHDrQCAwOtX3/91bXshx9+sCpVqpTrmHmJjY21qlSpku/6c+fOWWFhYVarVq2sM2fOuJYvWrTIkmQ9+eSTlmVZ1okTJyxJ1rRp0/I91oIFCyxJ1pYtWwotF4CC0SwFoFBOp1OjRo3KtTwgIMD1+8mTJ3Xs2DFde+21On36tPbu3VvocW+55RaFhoa63l977bWSpJ9++qnQffv06aNGjRq53rdp00bBwcGufbOysrR8+XINHjxYtWvXdm3XuHFjxcTEFHr8oti6dauOHj2qe++9163D84ABA9SsWTN9/vnnksx98vf316pVq3TixIk8j5VTw7No0SJlZmZ6pHxARUW4AVCoOnXqyN/fP9fy3bt3a8iQIQoJCVFwcLBq1arl6oycmppa6HHr1avn9j4n6OQXAAraN2f/nH2PHj2qM2fOqHHjxrm2y2tZcRw6dEiS1LRp01zrmjVr5lrvdDr13HPPacmSJQoPD1ePHj3073//W0lJSa7te/bsqaFDh2ry5MmqWbOmBg0apFmzZikjI8MjZQUqEsINgEJdXEOTIyUlRT179tS3336rK
VOm6LPPPtOyZcv03HPPSZKys7MLPa6vr2+ey60izFBxJfvaYcKECdq/f7/i4uJUuXJlPfHEE2revLm2b98uyXSSnjdvnjZs2KBx48bp119/1R133KEOHTowFB24TIQbAMWyatUq/f7774qPj9f48eP1xz/+UX369HFrZrJTWFiYKleurAMHDuRal9ey4qhfv74kad++fbnW7du3z7U+R6NGjfTggw/qyy+/1K5du3Tu3Dm98MILbtt07dpVzz77rLZu3aoPP/xQu3fv1pw5czxSXqCiINwAKJacmpOLa0rOnTunN954w64iufH19VWfPn20cOFCHTlyxLX8wIEDWrJkiUfO0bFjR4WFhenNN990az5asmSJ9uzZowEDBkgy8wKdPXvWbd9GjRopKCjItd+JEydy1Tq1a9dOkmiaAi4TQ8EBFEu3bt0UGhqq2NhY/f3vf5fD4dD7779fqpqFJk2apC+//FLdu3fXmDFjlJWVpenTp6tVq1basWNHkY6RmZmpZ555Jtfy6tWr695779Vzzz2nUaNGqWfPnho2bJhrKHiDBg10//33S5L279+v3r176+abb1aLFi1UqVIlLViwQMnJybr11lslSf/973/1xhtvaMiQIWrUqJFOnjyp//znPwoODlb//v09dk+AioBwA6BYatSooUWLFunBBx/U448/rtDQUN1+++3q3bu3+vbta3fxJEkdOnTQkiVL9NBDD+mJJ55QVFSUpkyZoj179hRpNJdkaqOeeOKJXMsbNWqke++9VyNHjlRgYKCmTp2qhx9+WFWqVNGQIUP03HPPuUZARUVFadiwYVqxYoXef/99VapUSc2aNdPHH3+soUOHSjIdijdv3qw5c+YoOTlZISEh6ty5sz788EM1bNjQY/cEqAh4thSACmfw4MHavXu3fvjhB7uLAsAL6HMDoFw7c+aM2/sffvhBixcvVq9evewpEACvo+YGQLkWGRmpkSNH6qqrrtKhQ4c0Y8YMZWRkaPv27WrSpIndxQPgBfS5AVCu9evXTx999JGSkpLkdDoVHR2tf/3rXwQboByj5gYAAJQr9LkBAADlCuEGAACUKxWuz012draOHDmioKAgORwOu4sDAACKwLIsnTx5UrVr15aPT8F1MxUu3Bw5ckRRUVF2FwMAABTD4cOHVbdu3QK3qXDhJigoSJK5OcHBwTaXBgAAFEVaWpqioqJc3+MFqXDhJqcpKjg4mHADAEAZU5QuJXQoBgAA5QrhBgAAlCuEGwAAUK5UuD43AAB7ZGdn69y5c3YXA6WYv79/ocO8i4JwAwDwunPnzungwYPKzs62uygoxXx8fNSwYUP5+/tf0XEINwAAr7IsS4mJifL19VVUVJRH/mWO8idnkt3ExETVq1fviibaJdwAALzq/PnzOn36tGrXrq3AwEC7i4NSrFatWjpy5IjOnz8vPz+/Yh+H+AwA8KqsrCxJuuKmBpR/OZ+RnM9McRFuAAAlguf5oTCe+owQbgAAQLlCuAEAoIQ0aNBAL7/8cpG3X7VqlRwOh1JSUrxWpvKIcAMAwCUcDkeBr0mTJhXruFu2bNE999xT5O27deumxMREhYSEFOt8RVXeQhSjpTwkI0NKTpZ8fKRCnsQOACjlEhMTXb//73//05NPPql9+/a5llWtWtX1u2VZysrKUqVKhX+l1qpV67LK4e/vr4iIiMvaB9TceMw330j160u9etldEgDAlYqIiHC9QkJC5HA4XO/37t2roKAgLVmyRB06dJDT6dTatWv1448/atCgQQoPD1fVqlXVqVMnLV++3O24lzZLORwOvf322xoyZIgCAwPVpEkTffrpp671l9aoxMfHq1q1alq6dKmaN2+uqlWrql+/fm5h7Pz58/r73/+uatWqqUaNGnr44YcVGxurwYMHF/t+nDhxQiNGjFBoaKgCAwMVExOjH374wbX+0KFDGjhwoEJDQ1WlShW1bNlSixcvdu07fPhw1apVSwEBAWrSpIlmz
ZpV7LIUBeHGQ3LmpGLyTQAomGVJp07Z87Isz13HI488oqlTp2rPnj1q06aN0tPT1b9/f61YsULbt29Xv379NHDgQCUkJBR4nMmTJ+vmm2/Wd999p/79+2v48OE6fvx4vtufPn1azz//vN5//32tWbNGCQkJeuihh1zrn3vuOX344YeaNWuW1q1bp7S0NC1cuPCKrnXkyJHaunWrPv30U23YsEGWZal///7KzMyUJI0dO1YZGRlas2aNdu7cqeeee85Vu/XEE0/o+++/15IlS7Rnzx7NmDFDNWvWvKLyFMqqYFJTUy1JVmpqqkePu3mzZUmWVb++Rw8LAGXemTNnrO+//946c+aMZVmWlZ5u/n9pxys9/fLLP2vWLCskJMT1/quvvrIkWQsXLix035YtW1qvvfaa6339+vWtl156yfVekvX444+73qenp1uSrCVLlrid68SJE66ySLIOHDjg2uf111+3wsPDXe/Dw8OtadOmud6fP3/eqlevnjVo0KB8y3npeS62f/9+S5K1bt0617Jjx45ZAQEB1scff2xZlmW1bt3amjRpUp7HHjhwoDVq1Kh8z32xSz8rF7uc729qbjyEmhsAqFg6duzo9j49PV0PPfSQmjdvrmrVqqlq1aras2dPoTU3bdq0cf1epUoVBQcH6+jRo/luHxgYqEaNGrneR0ZGurZPTU1VcnKyOnfu7Frv6+urDh06XNa1XWzPnj2qVKmSunTp4lpWo0YNNW3aVHv27JEk/f3vf9czzzyj7t2766mnntJ3333n2nbMmDGaM2eO2rVrp3/+859av359sctSVIQbDyHcAEDRBAZK6en2vDz59IcqVaq4vX/ooYe0YMEC/etf/9LXX3+tHTt2qHXr1oU+Cf3Sxww4HI4CHzCa1/aWJ9vbiuGuu+7STz/9pL/+9a/auXOnOnbsqNdee02SFBMTo0OHDun+++/XkSNH1Lt3b7dmNG8g3HgI4QYAisbhkKpUseflzUmS161bp5EjR2rIkCFq3bq1IiIi9PPPP3vvhHkICQlReHi4tmzZ4lqWlZWlb775ptjHbN68uc6fP69Nmza5lv3+++/at2+fWrRo4VoWFRWl0aNHa/78+XrwwQf1n//8x7WuVq1aio2N1QcffKCXX35ZM2fOLHZ5ioKh4B5CuAGAiq1JkyaaP3++Bg4cKIfDoSeeeKLAGhhvue+++xQXF6fGjRurWbNmeu2113TixIkiPdpg586dCgoKcr13OBxq27atBg0apLvvvltvvfWWgoKC9Mgjj6hOnToaNGiQJGnChAmKiYnRH/7wB504cUJfffWVmjdvLkl68skn1aFDB7Vs2VIZGRlatGiRa523EG48hHADABXbiy++qDvuuEPdunVTzZo19fDDDystLa3Ey/Hwww8rKSlJI0aMkK+vr+655x717dtXvr6+he7bo0cPt/e+vr46f/68Zs2apfHjx+uPf/yjzp07px49emjx4sWuJrKsrCyNHTtWv/zyi4KDg9WvXz+99NJLksxcPRMnTtTPP/+sgIAAXXvttZozZ47nL/wiDsvuhroSlpaWppCQEKWmpio4ONhjx92zR2rRQqpRQzp2zGOHBYAy7+zZszp48KAaNmyoypUr212cCic7O1vNmzfXzTffrKefftru4hSooM/K5Xx/U3PjIdTcAABKg0OHDunLL79Uz549lZGRoenTp+vgwYO67bbb7C5aiaFDsYcQbgAApYGPj4/i4+PVqVMnde/eXTt37tTy5cu93s+lNCk14Wbq1KlyOByaMGFCgdvNnTtXzZo1U+XKldW6dWvX9M52I9wAAEqDqKgorVu3TqmpqUpLS9P69etz9aUp70pFuNmyZYveeustt4mM8rJ+/XoNGzZMd955p7Zv367Bgwdr8ODB2rVrVwmVNH+EGwAASgfbw016erqGDx+u//znPwoNDS1w21deeUX9+vXTP/7xDzVv3lxPP/20rr76ak2fPr2ESps/wg0AFKyCj
V9BMXjqM2J7uBk7dqwGDBigPn36FLrthg0bcm3Xt29fbdiwId99MjIylJaW5vbyBsINAOQtZwhyYTP1AjmfkaIMWy+IraOl5syZo2+++cZtJsWCJCUlKTw83G1ZeHi4kpKS8t0nLi5OkydPvqJyFgXhBgDyVqlSJQUGBuq3336Tn5+ffHxs/3c1SqHs7Gz99ttvCgwMVKVKVxZPbAs3hw8f1vjx47Vs2TKvznswceJEPfDAA673aWlpioqK8vh5CDcAkDeHw6HIyEgdPHhQhw4dsrs4KMV8fHxUr169Is2mXBDbws22bdt09OhRXX311a5lWVlZWrNmjaZPn66MjIxc1VIRERFKTk52W5acnKyIiIh8z+N0OuV0Oj1b+DwQbgAgf/7+/mrSpAlNUyiQv7+/R2r2bAs3vXv31s6dO92WjRo1Ss2aNdPDDz+cZ3tbdHS0VqxY4TZcfNmyZYqOjvZ2cQuV87ewLPPy5sPZAKAs8vHxYYZilAjbwk1QUJBatWrltqxKlSqqUaOGa/mIESNUp04dxcXFSZLGjx+vnj176oUXXtCAAQM0Z84cbd261etPFy2Ki4NmdrZ0hX2hAABAMZXqXl0JCQlKTEx0ve/WrZtmz56tmTNnqm3btpo3b54WLlyYKyTZ4dJwAwAA7MGDMz12XCkkxPx+9qxUAt18AACoMC7n+7tU19yUJdTcAABQOhBuPIRwAwBA6UC48RDCDQAApQPhxkMINwAAlA6EGw8h3AAAUDoQbjyEcAMAQOlAuPGQi2ckJtwAAGAfwo2HOBwXAg7hBgAA+xBuPIiHZwIAYD/CjQcRbgAAsB/hxoMINwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNBxFuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDcelBNusrLsLQcAABUZ4caDqLkBAMB+hBsPItwAAGA/wo0H+fqan4QbAADsQ7jxIGpuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9bA03M2bMUJs2bRQcHKzg4GBFR0dryZIl+W4fHx8vh8Ph9qpcuXIJlrhghBsAAOxXyc6T161bV1OnTlWTJk1kWZb++9//atCgQdq+fbtatmyZ5z7BwcHat2+f673D4Sip4haKcAMAgP1sDTcDBw50e//ss89qxowZ2rhxY77hxuFwKCIioiSKd9kINwAA2K/U9LnJysrSnDlzdOrUKUVHR+e7XXp6uurXr6+oqCgNGjRIu3fvLsFSFoxwAwCA/WytuZGknTt3Kjo6WmfPnlXVqlW1YMECtWjRIs9tmzZtqnfffVdt2rRRamqqnn/+eXXr1k27d+9W3bp189wnIyNDGRkZrvdpaWleuQ6JcAMAQGlge81N06ZNtWPHDm3atEljxoxRbGysvv/++zy3jY6O1ogRI9SuXTv17NlT8+fPV61atfTWW2/le/y4uDiFhIS4XlFRUd66FMINAAClgO3hxt/fX40bN1aHDh0UFxentm3b6pVXXinSvn5+fmrfvr0OHDiQ7zYTJ05Uamqq63X48GFPFT0Xwg0AAPazPdxcKjs7260ZqSBZWVnauXOnIiMj893G6XS6hprnvLyFcAMAgP1s7XMzceJExcTEqF69ejp58qRmz56tVatWaenSpZKkESNGqE6dOoqLi5MkTZkyRV27dlXjxo2VkpKiadOm6dChQ7rrrrvsvAwXwg0AAPazNdwcPXpUI0aMUGJiokJCQtSmTRstXbpUN9xwgyQpISFBPj4XKpdOnDihu+++W0lJSQoNDVWHDh20fv36fDsglzTCDQAA9nNYlmXZXYiSlJaWppCQEKWmpnq8iWrAAGnxYmnWLGnkSI8eGgCACu1yvr9LXZ+bsoyaGwAA7Ee48SDCDQAA9iPce
BDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNB+WEm6wse8sBAEBFRrjxIGpuAACwH+HGg3x9zU/CDQAA9iHceBA1NwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNBxFuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDceRLgBAMB+hBsPItwAAGA/wo0HEW4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw8i3AAAYD/CjQflhJusLHvLAQBARUa48SBqbgAAsJ+t4WbGjBlq06aNgoODFRwcrOjoaC1ZsqTAfebOnatmzZqpcuXKat26tRYvXlxCpS0c4QYAAPvZGm7q1q2rqVOnatu2bdq6dauuv/56DRo0SLt3785z+/Xr12vYsGG68847tX37dg0ePFiDBw/Wrl27SrjkefP1NT8JNwAA2MdhWZZldyEuVr16dU2bNk133nlnrnW33HKLTp06pUWLFrmWde3aVe3atdObb75ZpOOnpaUpJCREqampCg4O9li5JWnGDOnee6WhQ6V58zx6aAAAKrTL+f4uNX1usrKyNGfOHJ06dUrR0dF5brNhwwb16dPHbVnfvn21YcOGfI+bkZGhtLQ0t5e30CwFAID9bA83O3fuVNWqVeV0OjV69GgtWLBALVq0yHPbpKQkhYeHuy0LDw9XUlJSvsePi4tTSEiI6xUVFeXR8l+McAMAgP1sDzdNmzbVjh07tGnTJo0ZM0axsbH6/vvvPXb8iRMnKjU11fU6fPiwx459KcINAAD2q2R3Afz9/dW4cWNJUocOHbRlyxa98soreuutt3JtGxERoeTkZLdlycnJioiIyPf4TqdTTqfTs4XOB+EGAAD72V5zc6ns7GxlZGTkuS46OlorVqxwW7Zs2bJ8++iUNMINAAD2s7XmZuLEiYqJiVG9evV08uRJzZ49W6tWrdLSpUslSSNGjFCdOnUUFxcnSRo/frx69uypF154QQMGDNCcOXO0detWzZw5087LcCHcAABgP1vDzdGjRzVixAglJiYqJCREbdq00dKlS3XDDTdIkhISEuTjc6FyqVu3bpo9e7Yef/xxPfroo2rSpIkWLlyoVq1a2XUJbgg3AADYr9TNc+Nt3pzn5qOPpNtuk3r3lpYv9+ihAQCo0MrkPDflATU3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDceRLgBAMB+hBsPItwAAGA/wo0HEW4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw/KCTdZWfaWAwCAioxw40HU3AAAYD/CjQf5+pqfhBsAAOxDuPEgam4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw8i3AAAYD/CjQcRbgAAsB/hxoMINwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP1vDTVxcnDp16qSgoCCFhYVp8ODB2rdvX4H7xMfHy+FwuL0qV65cQiUumM9Fd9Oy7CsHAAAVma3hZvXq1Ro7dqw2btyoZcuWKTMzUzfeeKNOnTpV4H7BwcFKTEx0vQ4dOlRCJS7YxeGG2hsAAOxRyc6Tf/HFF27v4+PjFRYWpm3btqlHjx757udwOBQREeHt4l22S8ONr699ZQEAoKIqVX1uUlNTJUnVq1cvcLv09HTVr19fUVFRGjRokHbv3p3vthkZGUpLS3N7eQs1NwAA2K/UhJvs7GxNmDBB3bt3V6tWrfLdrmnTpnr33
Xf1ySef6IMPPlB2dra6deumX375Jc/t4+LiFBIS4npFRUV56xIINwAAlAIOyyodXV/HjBmjJUuWaO3atapbt26R98vMzFTz5s01bNgwPf3007nWZ2RkKCMjw/U+LS1NUVFRSk1NVXBwsEfKnuP0aalKFfN7evqF3wEAwJVJS0tTSEhIkb6/be1zk2PcuHFatGiR1qxZc1nBRpL8/PzUvn17HThwIM/1TqdTTqfTE8UsFDU3AADYz9ZmKcuyNG7cOC1YsEArV65Uw4YNL/sYWVlZ2rlzpyIjI71QwstzcbjJyrKvHAAAVGS21tyMHTtWs2fP1ieffKKgoCAlJSVJkkJCQhQQECBJGjFihOrUqaO4uDhJ0pQpU9S1a1c1btxYKSkpmjZtmg4dOqS77rrLtuvIQc0NAAD2szXczJgxQ5LUq1cvt+WzZs3SyJEjJUkJCQnyuSg1nDhxQnfffbeSkpIUGhqqDh06aP369WrRokVJFTtfhBsAAOxXajoUl5TL6ZBUHA6H+ZmcLIWFefzwAABUSJfz/V1qhoKXFzxfCgAAexFuPIxwAwCAvYoVbg4fPuw2ad7mzZs1YcIEzZw502MFK6sINwAA2KtY4ea2227TV199JUlKSkrSDTfcoM2bN+uxxx7TlClTPFrAsoZwAwCAvYoVbnbt2qXOnTtLkj7++GO1atVK69ev14cffqj4+HhPlq/MIdwAAGCvYoWbzMxM16y/y5cv15/+9CdJUrNmzZSYmOi50pVBhBsAAOxVrHDTsmVLvfnmm/r666+1bNky9evXT5J05MgR1ahRw6MFLGsINwAA2KtY4ea5557TW2+9pV69emnYsGFq27atJOnTTz91NVdVVIQbAADsVawZinv16qVjx44pLS1NoaGhruX33HOPAgMDPVa4sign3PBsKQAA7FGsmpszZ84oIyPDFWwOHTqkl19+Wfv27VNYBZ+W18/P/Dx/3t5yAABQURUr3AwaNEjvvfeeJCklJUVdunTRCy+8oMGDB7ueF1VRVfq/urDMTHvLAQBARVWscPPNN9/o2muvlSTNmzdP4eHhOnTokN577z29+uqrHi1gWZNTc0O4AQDAHsUKN6dPn1ZQUJAk6csvv9RNN90kHx8fde3aVYcOHfJoAcsamqUAALBXscJN48aNtXDhQh0+fFhLly7VjTfeKEk6evSoV560XZbQLAUAgL2KFW6efPJJPfTQQ2rQoIE6d+6s6OhoSaYWp3379h4tYFlDsxQAAPYq1lDwP//5z7rmmmuUmJjomuNGknr37q0hQ4Z4rHBlEc1SAADYq1jhRpIiIiIUERHhejp43bp1K/wEfhLNUgAA2K1YzVLZ2dmaMmWKQkJCVL9+fdWvX1/VqlXT008/rewKPjUvzVIAANirWDU3jz32mN555x1NnTpV3bt3lyStXbtWkyZN0tmzZ/Xss896tJBlCc1SAADYq1jh5r///a/efvtt19PAJalNmzaqU6eO7r333godbmiWAgDAXsVqljp+/LiaNWuWa3mzZs10/PjxKy5UWUazFAAA9ipWuGnbtq2mT5+ea/n06dPVpk2bKy5UWUazFAAA9ipWs9S///1vDRgwQMuXL3fNcbNhwwYdPnxYixcv9mgByxqapQAAsFexam569uyp/fv3a8iQIUpJSVFKSopuuukm7d69W++//76ny1im0CwFAIC9ij3PTe3atXN1HP7222/1zjvvaObMmVdcsLIqp+aGZikAAOxRrJob5I+aGwAA7EW48TDCDQAA9iLceBjNUgAA2Ouy+tzcdNNNBa5PSUm5krKUC9TcAABgr8sKNyEhIYWuHzFixBUVqKwj3AAAYK/LCjezZs3yVjnKDZqlAACwF31uPIyaGwAA7EW48TDCDQAA9rI13MTFxalTp04KCgpSWFiYBg8erH379hW639y5c9WsWTNVrlxZrVu3LlWPfKBZCgAAe9kablavXq2xY8dq48aNWrZsmTIzM3XjjTfq1KlT+
e6zfv16DRs2THfeeae2b9+uwYMHa/Dgwdq1a1cJljx/1NwAAGAvh2VZlt2FyPHbb78pLCxMq1evVo8ePfLc5pZbbtGpU6e0aNEi17KuXbuqXbt2evPNNws9R1pamkJCQpSamqrg4GCPlT3HK69IEyZIw4ZJs2d7/PAAAFRIl/P9Xar63KSmpkqSqlevnu82GzZsUJ8+fdyW9e3bVxs2bMhz+4yMDKWlpbm9vImnggMAYK9SE26ys7M1YcIEde/eXa1atcp3u6SkJIWHh7stCw8PV1JSUp7bx8XFKSQkxPWKioryaLkvRbMUAAD2KjXhZuzYsdq1a5fmzJnj0eNOnDhRqamprtfhw4c9evxL5YQbOhQDAGCPy5rEz1vGjRunRYsWac2aNapbt26B20ZERCg5OdltWXJysiIiIvLc3ul0yul0eqyshaFZCgAAe9lac2NZlsaNG6cFCxZo5cqVatiwYaH7REdHa8WKFW7Lli1bpujoaG8V87LQLAUAgL1srbkZO3asZs+erU8++URBQUGufjMhISEKCAiQJI0YMUJ16tRRXFycJGn8+PHq2bOnXnjhBQ0YMEBz5szR1q1bNXPmTNuu42I0SwEAYC9ba25mzJih1NRU9erVS5GRka7X//73P9c2CQkJSkxMdL3v1q2bZs+erZkzZ6pt27aaN2+eFi5cWGAn5JJEsxQAAPayteamKFPsrFq1Kteyv/zlL/rLX/7ihRJdOZqlAACwV6kZLVVe8PgFAADsRbjxMGpuAACwF+HGwwg3AADYi3DjYTRLAQBgL8KNh1FzAwCAvQg3Hka4AQDAXoQbD6NZCgAAexFuPIyaGwAA7EW48TDCDQAA9iLceBjNUgAA2Itw42EX19wU4ekSAADAwwg3HpYTbiQpK8u+cgAAUFERbjys0kWPIqVpCgCAkke48bCLa27oVAwAQMkj3HjYxeGGmhsAAEoe4cbDfH0v/E7NDQAAJY9w42EOx4V+N4QbAABKHuHGC3KapmiWAgCg5BFuvICaGwAA7EO48QIewQAAgH0IN15AsxQAAPYh3HgBzVIAANiHcOMFNEsBAGAfwo0X8GRwAADsQ7jxAmpuAACwD+HGCwg3AADYh3DjBTRLAQBgH8KNF1BzAwCAfQg3XkC4AQDAPoQbL6BZCgAA+xBuvICaGwAA7EO48QLCDQAA9iHceAHNUgAA2MfWcLNmzRoNHDhQtWvXlsPh0MKFCwvcftWqVXI4HLleSUlJJVPgIqLmBgAA+9gabk6dOqW2bdvq9ddfv6z99u3bp8TERNcrLCzMSyUsHsINAAD2qWTnyWNiYhQTE3PZ+4WFhalatWqeL5CH5ISbc+fsLQcAABVRmexz065dO0VGRuqGG27QunXr7C5OLgEB5ufZs/aWAwCAisjWmpvLFRkZqTfffFMdO3ZURkaG3n77bfXq1UubNm3S1Vdfnec+GRkZysjIcL1PS0vzejkDA83P06e9fioAAHCJMhVumjZtqqZNm7red+vWTT/++KNeeuklvf/++3nuExcXp8mTJ5dUESURbgAAsFOZbJa6WOfOnXXgwIF810+cOFGpqamu1+HDh71eppxwc+aM108FAAAuUaZqbvKyY8cORUZG5rve6XTK6XSWYIku9Lmh5gYAgJJna7hJT093q3U5ePCgduzYoerVq6tevXqaOHGifv31V7333nuSpJdfflkNGzZUy5YtdfbsWb399ttauXKlvvzyS7suIU80SwEAYB9bw83WrVt13XXXud4/8MADkqTY2FjFx8crMTFRCQkJrvXnzp3Tgw8+qF9//VWBgYFq06aNli9f7naM0oBwAwCAfRyWZVl2F6IkpaWlKSQkRKmpqQoODvbKOebNk/7yF6lHD2n1aq+cAgCACuVyvr/LfIfi0og+NwAA2Idw4wU0SwEAYB/CjRcQbgAAsA/hxguY5wYAAPsQbryAPjcAANiHcOMFFzdLVayxaAAA2I9w4wU54SYrS8rMtLcsA
ABUNIQbL8gJNxL9bgAAKGmEGy/w85N8/u/O0u8GAICSRbjxAoeD4eAAANiFcOMlhBsAAOxBuPES5roBAMAehBsvYa4bAADsQbjxEpqlAACwB+HGSwg3AADYg3DjJfS5AQDAHoQbL6HPDQAA9iDceAnNUgAA2INw4yU0SwEAYA/CjZdQcwMAgD0IN15CnxsAAOxBuPESam4AALAH4cZL6HMDAIA9CDdeQs0NAAD2INx4CX1uAACwB+HGS6pWNT9PnrS3HAAAVDSEGy+pUcP8/P13e8sBAEBFQ7jxkpo1zU/CDQAAJYtw4yU5NTfHjkmWZW9ZAACoSAg3XpITbrKypNRUe8sCAEBFQrjxksqVL3QqpmkKAICSQ7jxooubpgAAQMkg3HhRTqdiwg0AACWHcONFhBsAAEqereFmzZo1GjhwoGrXri2Hw6GFCxcWus+qVat09dVXy+l0qnHjxoqPj/d6OYuL4eAAAJQ8W8PNqVOn1LZtW73++utF2v7gwYMaMGCArrvuOu3YsUMTJkzQXXfdpaVLl3q5pMVDnxsAAEpeJTtPHhMTo5iYmCJv/+abb6phw4Z64YUXJEnNmzfX2rVr9dJLL6lv377eKmax0SwFAEDJK1N9bjZs2KA+ffq4Levbt682bNiQ7z4ZGRlKS0tze5UUwg0AACWvTIWbpKQkhYeHuy0LDw9XWlqazpw5k+c+cXFxCgkJcb2ioqJKoqiS6HMDAIAdylS4KY6JEycqNTXV9Tp8+HCJnZs+NwAAlDxb+9xcroiICCUnJ7stS05OVnBwsAICAvLcx+l0yul0lkTxcqFZCgCAklemam6io6O1YsUKt2XLli1TdHS0TSUq2MXNUjw8EwCAkmFruElPT9eOHTu0Y8cOSWao944dO5SQkCDJNCmNGDHCtf3o0aP1008/6Z///Kf27t2rN954Qx9//LHuv/9+O4pfqFq1JB8f8/DMxES7SwMAQMVga7jZunWr2rdvr/bt20uSHnjgAbVv315PPvmkJCkxMdEVdCSpYcOG+vzzz7Vs2TK1bdtWL7zwgt5+++1SOQxckvz8pIYNze8//GBvWQAAqCgcllWxGkzS0tIUEhKi1NRUBQcHe/18/ftLS5ZIM2dKd9/t9dMBAFAuXc73d5nqc1MW/eEP5uf+/faWAwCAioJw42U54YZmKQAASgbhxsuaNDE/qbkBAKBkEG68LKfm5sABM2oKAAB4F+HGy6KiJKdTysyUDh2yuzQAAJR/hBsv8/G50DS1d6+9ZQEAoCIg3JSA/5vGR+vW2VsOAAAqAsJNCbj+evNz5Up7ywEAQEVAuCkB111nfm7ZIp08aW9ZAAAo7wg3JaB+femqq8xoqa+/trs0AACUb4SbEpJTe7N0qb3lAACgvCPclJAhQ8zPd9+Vjh+3tywAAJRnhJsS0r+/1LatlJ4uvfqq3aUBAKD8ItyUEIdDeuwx8/u//y1t22ZveQAAKK8INyVo6FCpXz/pzBnpT3+Sdu2yu0QAAJQ/hJsS5OMjzZkjtWwpHTkide0qvfOO9Ntv0rlzdpcOAIDyoZLdBahoQkKk1aulW26RVqyQ7rrLLK9eXXr4YaldO7PN//4nffut1KKFNGGC1KiRnaUGAKDscFiWZdldiJKUlpamkJAQpaamKjg42LZynD8vvfKKNHly4RP7BQdL06dLqanSvHnSTTdJlSqZB3GGhEh//7tUtWrJlBsAADtczvc34cZmWVnmFR8vzZ0rJSebZqrmzU3tznvvSevXF3yM5s2lGjWkX36RGjaUZs0yEwdKJhAtXWr6+bRqZWqGfH29fVUAAHgW4aYApS3cFCYzU3r+eelf/zK//+1v0tatphnrqqtM81Vysvs+1aqZ/j19+phHPhw8eGFd27bSggVSdrZpBuvYUXrwQcnPz6w/dszUClWrVlJXCABA4Qg3BShr4SZHaqrpdFyrlvvyX36RPvxQCg83tTX33ivt3eu+TZ06UtOm0ubNZp6dgAATY
HKawyIiTCfnHj2kqVNNk9nQodJbb5kmMQAA7Ea4KUBZDTdFlZIiffmlFBRkamb8/KRPPzUB59dfTWjZtMls27atdPhw/jMm33CD9MQTpjkrMNB0gN6yxTRt/fGPpnkrM1N64w0TkP7yF1NjBACApxFuClDew82lLMtMIJgjO1vas0dKSpKuucbUBu3aJS1eLL3/vnTbbVLfvtKAAdKpU2afwEDTYfno0QvHadrUBJ9Zs0zokaQ2bcw8Pu+8Y5rN+vSROnSQYmNNTREAAMVFuClARQs3xbV8uTRxoglBv/xiltWuLV17rVn3++8Xtq1SxdTipKXlfazRo02gSk2VGjQwYWvvXqluXTOZ4e2308cHAFAwwk0BCDeXx7KktWtNX50+fUwzV0qKeZTEmjWmyerRR03wee456YsvpPvuM0PU16yRXn+98HNUrix162aCVM2a5pjbt0v332+aunKkpZnmtpyaqEtrpQAA5RfhpgCEm5J1++2mw3NUlDRtmqnxOX9eatzY1ObMmiXt3p3//sOHm+HtCxea5rP27U1fotWrpQ8+MMFq9Gjp44+l77+Xbr7ZNL0FBZk+RefPS5MmmWMMH26GxV93HR2lAaCsIdwUgHBTss6cMcPVb7zR1O5cyrJMaFm3znR63rvXTE5oWaajclEEBJjzXMzXV/p//09KSDCTHDocpnZo3ToTtP77XxNy8nPy5IVRZQAA+xFuCkC4KTs2bjQ1O2lpUv/+UufO0osvmjAUGmr67Lz1ltm2ZUszmeH/+3+m6ezcOfOzUqXcwUeSnE4zv89PP0kZGVKvXtJf/2o6T7/6quksHREhDRtmjnHbbSYUbdok7d9vht737Wv6GwEAvI9wUwDCTflhWWZCwmrVTC2Mw2FqXCpVku64wzykVDLP59q/3zRR3XefGf6+cGHu4/n4mHCTnl6080dFmQkWhw41s0z7+xe93BL9hQDgchBuCkC4qRgsy3RKXr3a9MP58ktT6/Liiyb8/P3vptbm+utNqHn/fVMjJJnmsaeeks6eNf2BEhPNXEGSaVq7+mrzUNPDh80yX1/zmjDBHG/pUrNu0CBpxw7T1Fatmulcffy4eTZYQoJ5vMakSVJYmDnP55+bGqdrrjFNaEUNSwBQERBuCkC4QX4SEkz4aNMm92SEx4+bUJTzkTl92nSQfvXV/CdBvFTnziZQHTt2YVnt2mZOofh4MyFiDh8f03n6uedMYHr9dRPYxowxASggwHSo3r/fDNkPDDT7WZZpxgsJKfZtAIBSiXBTAMINPOnsWfNsr82bpUceMTU43bqZB5l+9pkZsdW9uxnVldP35+qrpccfN8Pp9+y5cKzrrjP9fJYudQ9MnTub40umlufoUdPfKGf+oU6dzLmyskxt0IYN0rPPSgcOmFFid91lnib/5z+b/TMzzbD7qKiSuUcA4AmEmwIQbmCHzz4zTVC33iqNH2+anNLTTYfpw4fNTM4332z64ViWCS7//veFGhsp96gwHx8zc3Ramgkq6enSiRO5z12zpqktuuoq0z/ogw9MU9tdd0k7d5p1/fubh7NWrWr22bHD9Es6e9aEs3r1zHY9elx4yOru3WZE29/+Zmq7AMCbyly4ef311zVt2jQlJSWpbdu2eu2119S5c+c8t42Pj9eoUaPcljmdTp09e7ZI5yLcoCxZudL0/7nlFvM8r40bTU3O6tXmERg1a0oxMaa5SzJNWV27SjNmmCavI0eKfq5rrjEPZt2/P/+5h+rWNc8Wu/ZaE7wSEkyT2EcfmfXx8SYkNWt2YZ+sLDO5Y6tW5uGuF9u40fSFGj3ajGADgPyUqXDzv//9TyNGjNCbb76pLl266OWXX9bcuXO1b98+hYWF5do+Pj5e48eP1759+1zLHA6HwsPDi3Q+wg3Km99+k55+2jSBjRxpmsYOHTIdox9+WPr6a9M/6I03TBCJiTE1Mg88YOYfGjxYuvNO91FiDoc0ZIhpXps/39QYBQS4P3ZDMv2Qz
p+/MCT+1CkpMtKct3Vrs+6hh0wNUViY6ci9cqUZZXbsmAls585J//iHqamyLPPiAawALlWmwk2XLl3UqVMnTZ8+XZKUnZ2tqKgo3XfffXrkkUdybR8fH68JEyYoJSWlWOcj3ADGxY+v+Ppr0w+oWzfT96dZM9MUlbNddrbpq7N8uanZeeEF0xy2fLn0z3+aR21cjtBQE6ZyOlE7HGbo/quvmifP169vmub+8AfTp6hWLTMX0euvm07dt90mNWpk1ktmdupPP5U6djRzEDVrdqH5DED5UGbCzblz5xQYGKh58+Zp8ODBruWxsbFKSUnRJ598kmuf+Ph43XXXXapTp46ys7N19dVX61//+pdatmyZ5zkyMjKUkZHhep+WlqaoqCjCDXAFMjNNbU5wsHTwoGmqysgwnaEXLTLLVq408w6NHSuNGGGeGH/8uOkf9PPP5jh/+Yt5ttj77xevHJMmmUA2dKg5V47mzc1s1KGhpsnu1VfNUP+aNc0+FzebXWz6dOmZZ8wM1n37Fq9MALzjcsKNrZPLHzt2TFlZWbmalMLDw7V3794892natKneffddtWnTRqmpqXr++efVrVs37d69W3Xr1s21fVxcnCZPnuyV8gMVlZ/fhZqRhg3NvD+ZmVKTJlLPnmZ5Zqap8cnpS7NnjxlC7+dnmtF69jTh5uxZ07w1a5ZpBvvoI9MsFRkpffONCUKbNpl5gAYMMM1b69eb+YMmTbpQphYtzLESE825hg83/Xxeesk0j+X48ssL4SY62sxP9M03pp/R44+bmqrYWLPs2DHTf+nDD811BgdL1aubsJTTf+j4cTOZ5Nmzpnx165rr9NS/nX77zYQyJn0Eis7WmpsjR46oTp06Wr9+vaKjo13L//nPf2r16tXatGlTocfIzMxU8+bNNWzYMD399NO51lNzA5QNmzebWp3IyLzXnz1ranlyTJggvfKK+f2uu0yti9Mpbd1qanMunjeob1/zKI0ZM0xQKoifn/u+ealWzdRU9e5tmun27zfLAwJMqDlxQnrySTPkvm5dc7wFC8y1+fiY6QAefdSMmtu2Tfr1V/Ow18hIE7w2bTKj6L76yvSVuusuaebM3AHHssy+efy7Dih3ykzNTc2aNeXr66vk5GS35cnJyYqIiCjSMfz8/NS+fXsdOHAgz/VOp1NOhmEApV4+AyRdLg42kukk3bKlqbHp3v3C8o4dzTD2t94yNTZ33206TUumpujjjy/M/vzSSyYMhYeb87dqJQ0caDpdp6aajtK9epnHdqSkmM7Pzz13YTTZokXmZ+3apgP3li0Xhus//nj+1/LZZ9KyZVLjxtJ777mvGzhQWrzYjDLL8fbbpg9UTkhq1swEnldflZYsMQGoXz8TrsLDTRPd009LXbqYJjvJ9HEKDCxaZ+3sbBPeAgIK3xYojUpFh+LOnTvrtddek2Q6FNerV0/jxo3Ls0PxpbKystSyZUv1799fL774YqHb06EYQI6sLDOnT8uW7uEpK8s0LeX1ZPjUVBNqQkKkceNMCFi50gSO+fPNaLWcTtfXX28eA5Kaaob05zz/bMoU9z5CV19tQtHevRfmNapb19TmtGt3Yah9fvz8TJkDA81Q/NmzzTB7Hx8zr9Lnn5syVa1qpgu4+moTrDZulH780VxrzZpmZNuxY+YYP/0krVhhpgj45BPTafuBBy40M+7YYcJc374XOp9f6vffzb0ZMODCLNpAcZWZDsWSGQoeGxurt956S507d9bLL7+sjz/+WHv37lV4eLhGjBihOnXqKC4uTpI0ZcoUde3aVY0bN1ZKSoqmTZumhQsXatu2bWrRokWh5yPcAPCUrCxTm3JprdLF8hre/sMPptNyQoJ5Gv0NN5jls2ebGpkbb5TmzjWhJTvbDJPPGV9xzz2mCe/NN8371q3NUHtv6NXLhLacc912m5lfae1aU/MkmTB3/fVmGoLBg80IunPnTNPdvfeaa+3QwTThNWpkyi+ZsBQamn8zJHCpM
hVuJGn69OmuSfzatWunV199VV26dJEk9erVSw0aNFB8fLwk6f7779f8+fOVlJSk0NBQdejQQc8884zat29fpHMRbgCUZpf2LcrP5s1m2zZtTPNat27m0RtLl5rOz/ffb15795omqj//2Twq5JtvzGv/flMr1LGjqSH69FPTb6lWLRNApk690DTmcJhXdrZ7GS4NVlWrus+XlJcZM0wt1QMPmNqcMWPMeX77zTxapEEDE6ri403zYb9+ZpuDB02NUtOm5lqOHTMd2KtXz30OyzIj8ObNM3Mr9elT+P2UzDE3bjTnvLTGDvYrc+GmJBFuAFQUOWGkqJMiHj5smqcCAqTbbzejxCpVMv2UfvvNTLZ4zTXmcR29epnmvP37zSM9pk41tVh16phwdeiQabp79VVTI/XLL2am6uKoX9+Mgjt3Lve6m24yAWftWhPSqlUznay///7CNo8+aob4r1kjPfigNGqUmaJAMs2DAQHm+D17mhDVt6+55qAgE8QCA03frD17TBNcjx5m33/9ywSioUNNv6+LO3yfO2fCVffuuWfmRvEQbgpAuAGAwh06ZGp+7rzT9JkpzJ49pn/RTTflXfOUnW2atf73PxMaHn/c1BKtW2emAAgLM81Us2aZ2pP77jOdo5955kJtUK9epikvIcHsm5iYf3n8/U1I+ewz8/7OO02fqJznrz34oAkteQyylWQeY+JwSN99Z0bavf32hXV9+pjJLh977MKyzp1NbVNkpHk8yhtvmGkHatc29yUszAQlPz8TGLOyTO3SjBnSn/5k+jv98IMJjJeOiktKMqGtKDV65RnhpgCEGwCwh2WZR3RUqVLwvD0XN80lJppwcNVVpklKMsGgUiXTJPbAAyYwjBljjp+SYsJShw5SRISZLmDChAvHrlfPhKO8NG8uPfusCUJ5PYS2a1dTlotmF1H37mY4/9mzZuReSkruZ7o1bWrCzdq1powDB5opAz78MHe5BgwwtWZdu5owNH686d/Urp0JTcHBF2YXP3rUzPn088+mT1Pr1uZYv/9upiJo2dLURiUnmzCWnS0dOGBq8po0Me/Pnzf3L+fvcf68+d3X17w/eNDU5gUFFb3J1FsINwUg3ABAxTJ3rhnhdu6c6Z/05ZcmwJw/b2qQxo83X+YhIeaLf+NGU+sTHGxGl332mWnKeucd82V/221mLqK2bc2IseRk8/vx4+Z84eGmL9OwYSZ0XTwy7mK+vqYWaOnS3Ov8/EyQuHjfVq1M7dbmzWbUWlqa+z6RkabTdmZm7vmc+vUzk20mJpprfOYZU7v0yy9m5NysWSbADBhgwuHy5SZU3X67CYnt2pn3U6eacOrvbya79PU1YevVV80Iu6lTTfDMzPT8CDnCTQEINwCArVtNM9Ctt+Zdi5SWZgKEv78ZLt+o0YXtzp0zcxFde60JApL07rsmMAUEmMCT80SgX34xw+FPnzaBKTHR1NwcP26aumJjpb/9zdS23Hefmfdo925TGySZjuL//KcJSjlzKF2sdWsTapYvd+/wHRRkwkVKiqnpypml29fXfQ6li+U8CFcyNWUJCe6ze1+qWTNz7jNnTCCUTIg6eNCMnJs6Nf99i4NwUwDCDQDA0yzLBJw//MGEnoIkJ5umq/wG+VqWqWn6/nsz2isgwDQnrVxpanM6dTJBKzTU1BJJponql19MU9eyZaaG6vrrTThZt06aPNlMTjlmjFm+dauZ72j+fPMYk//+15y3QwcT5nKeTX3LLSYQ/fKLCVKvv25qck6dcq9VqlTJ7J8TnKKiTHj05By6hJsCEG4AABXZsWMm1Pz5zxeG0qemSvv2mea1AwekVasu9Nm5uGZr714TXE6eNKPPsrJMLdQ115gaqylTzHPdHnnENPN5EuGmAIQbAADKnsv5/i7i7AcAAABlA+EGAACUK4QbAABQrhBuAABAuUK4AQAA5QrhBgAAlCuEGwAAUK4QbgAAQLlCuAEAAOUK4QYAAJQrhBsAAFCuEG4AA
EC5QrgBAADlCuEGAACUK5XsLkBJsyxLknl0OgAAKBtyvrdzvscLUuHCzcmTJyVJUVFRNpcEAABcrpMnTyokJKTAbRxWUSJQOZKdna0jR44oKChIDofDY8dNS0tTVFSUDh8+rODgYI8dt6yo6NcvcQ8q+vVL3AOJe1DRr1/y3j2wLEsnT55U7dq15eNTcK+aCldz4+Pjo7p163rt+MHBwRX2Ay1x/RL3oKJfv8Q9kLgHFf36Je/cg8JqbHLQoRgAAJQrhBsAAFCuEG48xOl06qmnnpLT6bS7KLao6NcvcQ8q+vVL3AOJe1DRr18qHfegwnUoBgAA5Rs1NwAAoFwh3AAAgHKFcAMAAMoVwg0AAChXCDce8Prrr6tBgwaqXLmyunTpos2bN9tdJK+YNGmSHA6H26tZs2au9WfPntXYsWNVo0YNVa1aVUOHDlVycrKNJb5ya9as0cCBA1W7dm05HA4tXLjQbb1lWXryyScVGRmpgIAA9enTRz/88IPbNsePH9fw4cMVHBysatWq6c4771R6enoJXsWVKewejBw5Mtfnol+/fm7blOV7EBcXp06dOikoKEhhYWEaPHiw9u3b57ZNUT77CQkJGjBggAIDAxUWFqZ//OMfOn/+fEleSrEU5fp79eqV6zMwevRot23K6vVL0owZM9SmTRvXpHTR0dFasmSJa315/vvnKOwelLrPgIUrMmfOHMvf39969913rd27d1t33323Va1aNSs5OdnuonncU089ZbVs2dJKTEx0vX777TfX+tGjR1tRUVHWihUrrK1bt1pdu3a1unXrZmOJr9zixYutxx57zJo/f74lyVqwYIHb+qlTp1ohISHWwoULrW+//db605/+ZDVs2NA6c+aMa5t+/fpZbdu2tTZu3Gh9/fXXVuPGja1hw4aV8JUUX2H3IDY21urXr5/b5+L48eNu25Tle9C3b19r1qxZ1q5du6wdO3ZY/fv3t+rVq2elp6e7tinss3/+/HmrVatWVp8+fazt27dbixcvtmrWrGlNnDjRjku6LEW5/p49e1p3332322cgNTXVtb4sX79lWdann35qff7559b+/futffv2WY8++qjl5+dn7dq1y7Ks8v33z1HYPShtnwHCzRXq3LmzNXbsWNf7rKwsq3bt2lZcXJyNpfKOp556ymrbtm2e61JSUiw/Pz9r7ty5rmV79uyxJFkbNmwooRJ616Vf7NnZ2VZERIQ1bdo017KUlBTL6XRaH330kWVZlvX9999bkqwtW7a4tlmyZInlcDisX3/9tcTK7in5hZtBgwblu095uwdHjx61JFmrV6+2LKton/3FixdbPj4+VlJSkmubGTNmWMHBwVZGRkbJXsAVuvT6Lct8sY0fPz7ffcrT9ecIDQ213n777Qr3979Yzj2wrNL3GaBZ6gqcO3dO27ZtU58+fVzLfHx81KdPH23YsMHGknnPDz/8oNq1a+uqq67S8OHDlZCQIEnatm2bMjMz3e5Fs2bNVK9evXJ7Lw4ePKikpCS3aw4JCVGXLl1c17xhwwZVq1ZNHTt2dG3Tp08f+fj4aNOmTSVeZm9ZtWqVwsLC1LRpU40ZM0a///67a115uwepqamSpOrVq0sq2md/w4YNat26tcLDw13b9O3bV2lpadq9e3cJlv7KXXr9OT788EPVrFlTrVq10sSJE3X69GnXuvJ0/VlZWZozZ45OnTql6OjoCvf3l3Lfgxyl6TNQ4R6c6UnHjh1TVlaW2x9LksLDw7V3716bSuU9Xbp0UXx8vJo2barExERNnjxZ1157rXbt2qWkpCT5+/urWrVqbvuEh4crKSnJngJ7Wc515fX3z1mXlJSksLAwt/WVKlVS9erVy8196devn2666SY1bNhQP/74ox599FHFxMRow4YN8vX1LVf3IDs7WxMmTFD37t3VqlUrSSrSZz8pKSnPz0nOurIir+uXpNtuu03169dX7dq19d133+nhhx/Wvn37NH/+fEnl4/p37typ6
OhonT17VlWrVtWCBQvUokUL7dixo8L8/fO7B1Lp+wwQblBkMTExrt/btGmjLl26qH79+vr4448VEBBgY8lgp1tvvdX1e+vWrdWmTRs1atRIq1atUu/evW0smeeNHTtWu3bt0tq1a+0uii3yu/577rnH9Xvr1q0VGRmp3r1768cff1SjRo1Kuphe0bRpU+3YsUOpqamaN2+eYmNjtXr1aruLVaLyuwctWrQodZ8BmqWuQM2aNeXr65urV3xycrIiIiJsKlXJqVatmv7whz/owIEDioiI0Llz55SSkuK2TXm+FznXVdDfPyIiQkePHnVbf/78eR0/frzc3perrrpKNWvW1IEDBySVn3swbtw4LVq0SF999ZXq1q3rWl6Uz35ERESen5OcdWVBftefly5dukiS22egrF+/v7+/GjdurA4dOiguLk5t27bVK6+8UmH+/lL+9yAvdn8GCDdXwN/fXx06dNCKFStcy7Kzs7VixQq3dsjyKj09XT/++KMiIyPVoUMH+fn5ud2Lffv2KSEhodzei4YNGyoiIsLtmtPS0rRp0ybXNUdHRyslJUXbtm1zbbNy5UplZ2e7/uMvb3755Rf9/vvvioyMlFT274FlWRo3bpwWLFiglStXqmHDhm7ri/LZj46O1s6dO91C3rJlyxQcHOyq1i+tCrv+vOzYsUOS3D4DZfX685Odna2MjIxy//cvSM49yIvtnwGPd1GuYObMmWM5nU4rPj7e+v7776177rnHqlatmluP8PLiwQcftFatWmUdPHjQWrdundWnTx+rZs2a1tGjRy3LMsMh69WrZ61cudLaunWrFR0dbUVHR9tc6itz8uRJa/v27db27dstSdaLL75obd++3Tp06JBlWWYoeLVq1axPPvnE+u6776xBgwblORS8ffv21qZNm6y1a9daTZo0KTPDoC2r4Htw8uRJ66GHHrI2bNhgHTx40Fq+fLl19dVXW02aNLHOnj3rOkZZvgdjxoyxQkJCrFWrVrkNcz19+rRrm8I++znDYG+88UZrx44d1hdffGHVqlWrTAwFLuz6Dxw4YE2ZMsXaunWrdfDgQeuTTz6xrrrqKqtHjx6uY5Tl67csy3rkkUes1atXWwcPHrS+++4765FHHrEcDof15ZdfWpZVvv/+OQq6B6XxM0C48YDXXnvNqlevnuXv72917tzZ2rhxo91F8opbbrnFioyMtPz9/a06depYt9xyi3XgwAHX+jNnzlj33nuvFRoaagUGBlpDhgyxEhMTbSzxlfvqq68sSblesbGxlmWZ4eBPPPGEFR4ebjmdTqt3797Wvn373I7x+++/W8OGDbOqVq1qBQcHW6NGjbJOnjxpw9UUT0H34PTp09aNN95o1apVy/Lz87Pq169v3X333bnCfVm+B3lduyRr1qxZrm2K8tn/+eefrZiYGCsgIMCqWbOm9eCDD1qZmZklfDWXr7DrT0hIsHr06GFVr17dcjqdVuPGja1//OMfbnOcWFbZvX7Lsqw77rjDql+/vuXv72/VqlXL6t27tyvYWFb5/vvnKOgelMbPgMOyLMvz9UEAAAD2oM8NAAAoVwg3AACgXCHcAACAcoVwAwAAyhXCDQAAKFcINwAAoFwh3AAAgHKFcAOgQnI4HFq4cKHdxQDgBYQbACVu5MiRcjgcuV79+vWzu2gAyoFKdhcAQMXUr18/zZo1y22Z0+m0qTQAyhNqbgDYwul0KiIiwu0VGhoqyTQZzZgxQzExMQoICNBVV12lefPmue2/c+dOXX/99QoICFCNGjV0zz33KD093W2bd999Vy1btpTT6VRkZKTGjRvntv7YsWMaMmSIAgMD1aRJE3366aeudSdOnNDw4cNVq1YtBQQEqEmTJrnCGIDSiXADoFR64oknNHToUH377bcaPny4br31Vu3Zs0eSdOrUKfXt21ehoaHasmWL5s6dq+XLl7uFlxkzZmjs2LG65557tHPnTn366adq3Lix2zkmT56sm2++Wd999
5369++v4cOH6/jx467zf//991qyZIn27NmjGTNmqGbNmiV3AwAUn1cexwkABYiNjbV8fX2tKlWquL2effZZy7LMk6hHjx7ttk+XLl2sMWPGWJZlWTNnzrRCQ0Ot9PR01/rPP//c8vHxcT2RvHbt2tZjjz2WbxkkWY8//rjrfXp6uiXJWrJkiWVZljVw4EBr1KhRnrlgACWKPjcAbHHddddpxowZbsuqV6/u+j06OtptXXR0tHbs2CFJ2rNnj9q2basqVaq41nfv3l3Z2dnat2+fHA6Hjhw5ot69exdYhjZt2rh+r1KlioKDg3X06FFJ0pgxYzR06FB98803uvHGGzV48GB169atWNcKoGQRbgDYokqVKrmaiTwlICCgSNv5+fm5vXc4HMrOzpYkxcTE6NChQ1q8eLGWLVum3r17a+zYsXr++ec9Xl4AnkWfGwCl0saNG3O9b968uSSpefPm+vbbb3Xq1CnX+nXr1snHx0dNmzZVUFCQGjRooBUrVlxRGWrVqqXY2Fh98MEHevnllzVz5swrOh6AkkHNDQBbZGRkKCkpyW1ZpUqVXJ12586dq44dO+qaa67Rhx9+qM2bN+udd96RJA0fPlxPPfWUYmNjNWnSJP3222+677779Ne//lXh4eGSpEmTJmn06NEKCwtTTEyMTp48qXXr1um+++4rUvmefPJJdejQQS1btlRGRoYWLVrkClcASjfCDQBbfPHFF4qMjHRb1rRpU+3du1eSGck0Z84c3XvvvYqMjNRHH32kFi1aSJICAwO1dOlSjR8/Xp06dVJgYKCGDh2qF1980XWs2NhYnT17Vi+99JIeeugh1axZU3/+85+LXD5/f39NnDhRP//8swICAnTttddqzpw5HrhyAN7msCzLsrsQAHAxh8OhBQsWaPDgwXYXBUAZRJ8bAABQrhBuAABAuUKfGwClDq3lAK4ENTcAAKBcIdwAAIByhXADAADKFcINAAAoVwg3AACgXCHcAACAcoVwAwAAyhXCDQAAKFcINwAAoFz5/yAGBPoE3TruAAAAAElFTkSuQmCC\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "dXPARCwlSpqV" + }, + "execution_count": 28, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png b/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png new file mode 100644 index 000000000000..1668474caed6 Binary files /dev/null and b/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png differ diff --git a/baselines/fedavgm/_static/concentration_cifar10.png b/baselines/fedavgm/_static/concentration_cifar10.png new file mode 
100644 index 000000000000..0755ef8d66be Binary files /dev/null and b/baselines/fedavgm/_static/concentration_cifar10.png differ diff --git a/baselines/fedavgm/_static/concentration_cifar10_v2.png b/baselines/fedavgm/_static/concentration_cifar10_v2.png new file mode 100644 index 000000000000..bd3b9db1ff11 Binary files /dev/null and b/baselines/fedavgm/_static/concentration_cifar10_v2.png differ diff --git a/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png b/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png new file mode 100644 index 000000000000..042527a3ac21 Binary files /dev/null and b/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png new file mode 100644 index 000000000000..771e13514363 Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png new file mode 100644 index 000000000000..005aabbf6752 Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png new file mode 100644 index 000000000000..313c8299336f Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png differ diff --git a/baselines/fedavgm/conf-colab.sh b/baselines/fedavgm/conf-colab.sh new file mode 100644 index 000000000000..822fe2f273e1 --- /dev/null +++ b/baselines/fedavgm/conf-colab.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Shellscript to configure the environment on the Google Colab terminal + +# fix issue with ctypes on 
Colab instance +apt-get update +apt-get install -y libffi-dev + +# Install pyenv +curl https://pyenv.run | bash +export PYENV_ROOT="$HOME/.pyenv" +command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)" + +# this version is specific to the FedAvgM baseline +pyenv install 3.10.6 +pyenv global 3.10.6 + +# install Poetry +curl -sSL https://install.python-poetry.org | python3 - +export PATH="/root/.local/bin:$PATH" + +# install and set environment with Poetry +poetry install +poetry shell diff --git a/baselines/fedavgm/fedavgm/__init__.py b/baselines/fedavgm/fedavgm/__init__.py new file mode 100644 index 000000000000..a5e567b59135 --- /dev/null +++ b/baselines/fedavgm/fedavgm/__init__.py @@ -0,0 +1 @@ +"""Template baseline package.""" diff --git a/baselines/fedavgm/fedavgm/client.py b/baselines/fedavgm/fedavgm/client.py new file mode 100644 index 000000000000..6500bdc9c737 --- /dev/null +++ b/baselines/fedavgm/fedavgm/client.py @@ -0,0 +1,70 @@ +"""Define the Flower Client and function to instantiate it.""" + +import math + +import flwr as fl +from hydra.utils import instantiate +from keras.utils import to_categorical + + +class FlowerClient(fl.client.NumPyClient): + """Standard Flower client.""" + + # pylint: disable=too-many-arguments + def __init__(self, x_train, y_train, x_val, y_val, model, num_classes) -> None: + # local model + self.model = instantiate(model) + + # local dataset + self.x_train, self.y_train = x_train, to_categorical( + y_train, num_classes=num_classes + ) + self.x_val, self.y_val = x_val, to_categorical(y_val, num_classes=num_classes) + + def get_parameters(self, config): + """Return the parameters of the current local model.""" + return self.model.get_weights() + + def fit(self, parameters, config): + """Implement distributed fit function for a given client.""" + self.model.set_weights(parameters) + + self.model.fit( + self.x_train, + self.y_train, + epochs=config["local_epochs"], + 
batch_size=config["batch_size"], + verbose=False, + ) + return self.model.get_weights(), len(self.x_train), {} + + def evaluate(self, parameters, config): + """Implement distributed evaluation for a given client.""" + self.model.set_weights(parameters) + loss, acc = self.model.evaluate(self.x_val, self.y_val, verbose=False) + return loss, len(self.x_val), {"accuracy": acc} + + +def generate_client_fn(partitions, model, num_classes): + """Generate the client function that creates the Flower Clients.""" + + def client_fn(cid: str) -> FlowerClient: + """Create a Flower client representing a single organization.""" + full_x_train_cid, full_y_train_cid = partitions[int(cid)] + + # Use 10% of the client's training data for validation + split_idx = math.floor(len(full_x_train_cid) * 0.9) + x_train_cid, y_train_cid = ( + full_x_train_cid[:split_idx], + full_y_train_cid[:split_idx], + ) + x_val_cid, y_val_cid = ( + full_x_train_cid[split_idx:], + full_y_train_cid[split_idx:], + ) + + return FlowerClient( + x_train_cid, y_train_cid, x_val_cid, y_val_cid, model, num_classes + ) + + return client_fn diff --git a/baselines/fedavgm/fedavgm/common.py b/baselines/fedavgm/fedavgm/common.py new file mode 100644 index 000000000000..0ce9d04dc544 --- /dev/null +++ b/baselines/fedavgm/fedavgm/common.py @@ -0,0 +1,494 @@ +# Copyright 2020 Adap GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def float_to_int(i: float) -> int:
    """Cast ``i`` to ``int``, raising if any fractional part would be lost.

    Parameters
    ----------
    i : float
        Value expected to be integral (e.g. a computed split index).

    Returns
    -------
    int
        ``i`` as an exact integer.

    Raises
    ------
    ValueError
        If ``i`` has a non-zero fractional part.
    """
    # ValueError is more precise than the bare Exception raised before and
    # remains backward-compatible for callers catching Exception.
    if not i.is_integer():
        raise ValueError("Cast would drop decimals")

    return int(i)
def combine_partitions(xy_list_0: XYList, xy_list_1: XYList) -> XYList:
    """Merge two partition lists element-wise.

    Partition ``i`` of the result stacks partition ``i`` of each input list
    (samples first, then labels) along axis 0.
    """
    merged = []
    for (x_a, y_a), (x_b, y_b) in zip(xy_list_0, xy_list_1):
        x_cat = np.concatenate([x_a, x_b], axis=0)
        y_cat = np.concatenate([y_a, y_b], axis=0)
        merged.append((x_cat, y_cat))
    return merged
def create_partitions(
    unpartitioned_dataset: XY,
    iid_fraction: float,
    num_partitions: int,
) -> XYList:
    """Create a partitioned version of a training or test set.

    The data is shuffled, interleaved by label, and split at ``iid_fraction``;
    the second split is label-shifted (see ``shift``) before both splits are
    partitioned and merged shard-by-shard.

    Currently tested and supported are MNIST, FashionMNIST and CIFAR-10/100
    """
    features, labels = unpartitioned_dataset

    features, labels = shuffle(features, labels)
    features, labels = sort_by_label_repeating(features, labels)

    (x_iid, y_iid), (x_rest, y_rest) = split_at_fraction(
        features, labels, fraction=iid_fraction
    )

    # Group the classes of the second (non-IID) split into two halves.
    x_rest, y_rest = shift(x_rest, y_rest)

    iid_parts = partition(x_iid, y_iid, num_partitions)
    shifted_parts = partition(x_rest, y_rest, num_partitions)

    merged = combine_partitions(iid_parts, shifted_parts)

    # Normalize x to 4-D and y to 1-D where needed.
    return [adjust_xy_shape(pair) for pair in merged]
def adjust_x_shape(nda: np.ndarray) -> np.ndarray:
    """Append a trailing channel axis: shape (x, y, z) becomes (x, y, z, 1)."""
    dim0, dim1, dim2 = nda.shape[0], nda.shape[1], nda.shape[2]
    return nda.reshape((dim0, dim1, dim2, 1))
+ """ + if split_idx.ndim != 1: + raise ValueError("Variable `split_idx` must be a 1-D numpy array.") + if split_idx.dtype != np.int64: + raise ValueError("Variable `split_idx` must be of type np.int64.") + if split_idx[0] != 0: + raise ValueError("First value of `split_idx` must be 0.") + if split_idx[-1] >= x.shape[0]: + raise ValueError( + """Last value in `split_idx` must be less than + the number of samples in `x`.""" + ) + if not np.all(split_idx[:-1] <= split_idx[1:]): + raise ValueError("Items in `split_idx` must be in increasing order.") + + num_splits: int = len(split_idx) + split_idx = np.append(split_idx, x.shape[0]) + + list_samples_split: List[List[np.ndarray]] = [[] for _ in range(num_splits)] + for j in range(num_splits): + tmp_x = x[split_idx[j] : split_idx[j + 1]] # noqa: E203 + for sample in tmp_x: + list_samples_split[j].append(sample) + + return list_samples_split + + +def exclude_classes_and_normalize( + distribution: np.ndarray, exclude_dims: List[bool], eps: float = 1e-5 +) -> np.ndarray: + """Excludes classes from a distribution. + + This function is particularly useful when sampling without replacement. + Classes for which no sample is available have their probabilities are set to 0. + Classes that had probabilities originally set to 0 are incremented with + `eps` to allow sampling from remaining items. + + Args: + distribution (np.array): Distribution being used. + exclude_dims (List[bool]): Dimensions to be excluded. + eps (float, optional): Small value to be addad to non-excluded dimensions. + Defaults to 1e-5. + + Returns + ------- + np.ndarray: Normalized distributions. 
+ """ + if np.any(distribution < 0) or (not np.isclose(np.sum(distribution), 1.0)): + raise ValueError("distribution must sum to 1 and have only positive values.") + + if distribution.size != len(exclude_dims): + raise ValueError( + """Length of distribution must be equal + to the length `exclude_dims`.""" + ) + if eps < 0: + raise ValueError("""The value of `eps` must be positive and small.""") + + distribution[[not x for x in exclude_dims]] += eps + distribution[exclude_dims] = 0.0 + sum_rows = np.sum(distribution) + np.finfo(float).eps + distribution = distribution / sum_rows + + return distribution + + +def sample_without_replacement( + distribution: np.ndarray, + list_samples: List[List[np.ndarray]], + num_samples: int, + empty_classes: List[bool], +) -> Tuple[XY, List[bool]]: + """Sample from a list without replacement. + + using a given distribution. + + Args: + distribution (np.ndarray): Distribution used for sampling. + list_samples(List[List[np.ndarray]]): List of samples. + num_samples (int): Total number of items to be sampled. + empty_classes (List[bool]): List of booleans indicating which classes are empty. + This is useful to differentiate which classes should still be sampled. + + Returns + ------- + XY: Dataset contaning samples + List[bool]: empty_classes. 
+ """ + if np.sum([len(x) for x in list_samples]) < num_samples: + raise ValueError( + """Number of samples in `list_samples` is less than `num_samples`""" + ) + + # Make sure empty classes are not sampled + # and solves for rare cases where + if not empty_classes: + empty_classes = len(distribution) * [False] + + distribution = exclude_classes_and_normalize( + distribution=distribution, exclude_dims=empty_classes + ) + + data: List[np.ndarray] = [] + target: List[np.ndarray] = [] + + for _ in range(num_samples): + sample_class = np.where(np.random.multinomial(1, distribution) == 1)[0][0] + sample: np.ndarray = list_samples[sample_class].pop() + + data.append(sample) + target.append(sample_class) + + # If last sample of the class was drawn, then set the + # probability density function (PDF) to zero for that class. + if len(list_samples[sample_class]) == 0: + empty_classes[sample_class] = True + # Be careful to distinguish between classes that had zero probability + # and classes that are now empty + distribution = exclude_classes_and_normalize( + distribution=distribution, exclude_dims=empty_classes + ) + data_array: np.ndarray = np.concatenate([data], axis=0) + target_array: np.ndarray = np.array(target, dtype=np.int64) + + return (data_array, target_array), empty_classes + + +def get_partitions_distributions(partitions: XYList) -> Tuple[np.ndarray, List[int]]: + """Evaluate the distribution over classes for a set of partitions. 
+ + Args: + partitions (XYList): Input partitions + + Returns + ------- + np.ndarray: Distributions of size (num_partitions, num_classes) + """ + # Get largest available label + labels = set() + for _, y in partitions: + labels.update(set(y)) + list_labels = sorted(labels) + bin_edges = np.arange(len(list_labels) + 1) + + # Pre-allocate distributions + distributions = np.zeros((len(partitions), len(list_labels)), dtype=np.float32) + for idx, (_, _y) in enumerate(partitions): + hist, _ = np.histogram(_y, bin_edges) + distributions[idx] = hist / hist.sum() + + return distributions, list_labels + + +def create_lda_partitions( + dataset: XY, + dirichlet_dist: Optional[np.ndarray] = None, + num_partitions: int = 100, + concentration: Union[float, np.ndarray, List[float]] = 0.5, + accept_imbalanced: bool = False, + seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None, +) -> Tuple[XYList, np.ndarray]: + r"""Create imbalanced non-iid partitions using Latent Dirichlet Allocation (LDA). + + without resampling. + + Args: + dataset (XY): Dataset containing samples X and labels Y. + dirichlet_dist (numpy.ndarray, optional): previously generated distribution to + be used. This is useful when applying the same distribution for train and + validation sets. + num_partitions (int, optional): Number of partitions to be created. + Defaults to 100. + concentration (float, np.ndarray, List[float]): Dirichlet Concentration + (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions. + An :math:`\\alpha \\to \\Inf` generates uniform distributions over classes. + An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5. + accept_imbalanced (bool): Whether or not to accept imbalanced output classes. + Default False. + seed (None, int, SeedSequence, BitGenerator, Generator): + A seed to initialize the BitGenerator for generating the Dirichlet + distribution. 
This is defined in Numpy's official documentation as follows: + If None, then fresh, unpredictable entropy will be pulled from the OS. + One may also pass in a SeedSequence instance. + Additionally, when passed a BitGenerator, it will be wrapped by Generator. + If passed a Generator, it will be returned unaltered. + See official Numpy Documentation for further details. + + Returns + ------- + Tuple[XYList, numpy.ndarray]: List of XYList containing partitions + for each dataset and the dirichlet probability density functions. + """ + # pylint: disable=too-many-arguments,too-many-locals + + x, y = dataset + x, y = shuffle(x, y) + x, y = sort_by_label(x, y) + + if (x.shape[0] % num_partitions) and (not accept_imbalanced): + raise ValueError( + """Total number of samples must be a multiple of `num_partitions`. + If imbalanced classes are allowed, set + `accept_imbalanced=True`.""" + ) + + num_samples = num_partitions * [0] + for j in range(x.shape[0]): + num_samples[j % num_partitions] += 1 + + # Get number of classes and verify if they matching with + classes, start_indices = np.unique(y, return_index=True) + + # Make sure that concentration is np.array and + # check if concentration is appropriate + concentration = np.asarray(concentration) + + # Check if concentration is Inf, if so create uniform partitions + partitions: List[XY] = [(_, _) for _ in range(num_partitions)] + if float("inf") in concentration: + partitions = create_partitions( + unpartitioned_dataset=(x, y), + iid_fraction=1.0, + num_partitions=num_partitions, + ) + dirichlet_dist = get_partitions_distributions(partitions)[0] + + return partitions, dirichlet_dist + + if concentration.size == 1: + concentration = np.repeat(concentration, classes.size) + elif concentration.size != classes.size: # Sequence + raise ValueError( + f"The size of the provided concentration ({concentration.size}) ", + f"must be either 1 or equal number of classes {classes.size})", + ) + + # Split into list of list of samples per 
class + list_samples_per_class: List[List[np.ndarray]] = split_array_at_indices( + x, start_indices + ) + + if dirichlet_dist is None: + dirichlet_dist = np.random.default_rng(seed).dirichlet( + alpha=concentration, size=num_partitions + ) + + if dirichlet_dist.size != 0: + if dirichlet_dist.shape != (num_partitions, classes.size): + raise ValueError( + f"""The shape of the provided dirichlet distribution + ({dirichlet_dist.shape}) must match the provided number + of partitions and classes ({num_partitions},{classes.size})""" + ) + + # Assuming balanced distribution + empty_classes = classes.size * [False] + for partition_id in range(num_partitions): + partitions[partition_id], empty_classes = sample_without_replacement( + distribution=dirichlet_dist[partition_id].copy(), + list_samples=list_samples_per_class, + num_samples=num_samples[partition_id], + empty_classes=empty_classes, + ) + + return partitions, dirichlet_dist diff --git a/baselines/fedavgm/fedavgm/conf/base.yaml b/baselines/fedavgm/fedavgm/conf/base.yaml new file mode 100644 index 000000000000..3c2c281911a3 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/base.yaml @@ -0,0 +1,24 @@ +--- +num_clients: 10 +num_rounds: 5 # original experiments (paper) uses 10000 +fraction_evaluate: 0 # fraction of clients usied during validation +num_cpus: 1 +num_gpus: 0 + +noniid: + concentration: 0.1 # concentrations used in the paper [100., 10., 1., 0.5, 0.2, 0.1, 0.05, 0.0] + +server: + momentum: 0.9 + learning_rate: 1.0 + reporting_fraction: 0.05 # values used in the paper 0.05, 0.1, 0.2 (not used for Figure 5), 0.4 + +client: + local_epochs: 1 # in the paper it is used 1 or 5 + batch_size: 64 # in the paper fixed at 64 + lr: 0.01 # client learning rate + +defaults: + - strategy: custom-fedavgm + - model: cnn + - dataset: cifar10 diff --git a/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml b/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml new file mode 100644 index 000000000000..4894ba5d675f --- /dev/null 
+++ b/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml @@ -0,0 +1,4 @@ +--- +_target_: fedavgm.dataset.cifar10 +num_classes: 10 +input_shape: [32, 32, 3] \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml b/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml new file mode 100644 index 000000000000..2dfa07f1c60a --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml @@ -0,0 +1,4 @@ +--- +_target_: fedavgm.dataset.fmnist +num_classes: 10 +input_shape: [28, 28, 1] \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/model/cnn.yaml b/baselines/fedavgm/fedavgm/conf/model/cnn.yaml new file mode 100644 index 000000000000..c25463693c7f --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/model/cnn.yaml @@ -0,0 +1,5 @@ +--- +_target_: fedavgm.models.cnn +input_shape: ${dataset.input_shape} +num_classes: ${dataset.num_classes} +learning_rate: ${client.lr} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml b/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml new file mode 100644 index 000000000000..8c2a670ee978 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml @@ -0,0 +1,5 @@ +--- +_target_: fedavgm.models.tf_example +input_shape: ${dataset.input_shape} +num_classes: ${dataset.num_classes} +learning_rate: ${client.lr} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml b/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml new file mode 100644 index 000000000000..526c9714ed73 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml @@ -0,0 +1,13 @@ +--- +_target_: fedavgm.strategy.CustomFedAvgM +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +server_learning_rate: ${server.learning_rate} +server_momentum: ${server.momentum} +on_fit_config_fn: + _target_: 
fedavgm.server.get_on_fit_config + config: ${client} +initial_parameters: + _target_: fedavgm.models.model_to_parameters + model: ${model} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml b/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml new file mode 100644 index 000000000000..1b2cde85fe6c --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml @@ -0,0 +1,8 @@ +--- +_target_: flwr.server.strategy.FedAvg +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +on_fit_config_fn: + _target_: fedavgm.server.get_on_fit_config + config: ${client} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml b/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml new file mode 100644 index 000000000000..ce88887c02ab --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml @@ -0,0 +1,13 @@ +--- +_target_: flwr.server.strategy.FedAvgM +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +server_learning_rate: ${server.learning_rate} +server_momentum: ${server.momentum} +on_fit_config_fn: + _target_: fedavgm.server.get_on_fit_config + config: ${client} +initial_parameters: + _target_: fedavgm.models.model_to_parameters + model: ${model} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/dataset.py b/baselines/fedavgm/fedavgm/dataset.py new file mode 100644 index 000000000000..939a42fda5ae --- /dev/null +++ b/baselines/fedavgm/fedavgm/dataset.py @@ -0,0 +1,57 @@ +"""Dataset utilities for federated learning.""" + +import numpy as np +from tensorflow import keras + +from fedavgm.common import create_lda_partitions + + +def cifar10(num_classes, input_shape): + """Prepare the CIFAR-10. + + This method considers CIFAR-10 for creating both train and test sets. The sets are + already normalized. 
+ """ + print(f">>> [Dataset] Loading CIFAR-10. {num_classes} | {input_shape}.") + (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + input_shape = x_train.shape[1:] + num_classes = len(np.unique(y_train)) + + return x_train, y_train, x_test, y_test, input_shape, num_classes + + +def fmnist(num_classes, input_shape): + """Prepare the FMNIST. + + This method considers FMNIST for creating both train and test sets. The sets are + already normalized. + """ + print(f">>> [Dataset] Loading FMNIST. {num_classes} | {input_shape}.") + (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data() + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + input_shape = x_train.shape[1:] + num_classes = len(np.unique(y_train)) + + return x_train, y_train, x_test, y_test, input_shape, num_classes + + +def partition(x_train, y_train, num_clients, concentration): + """Create non-iid partitions. + + The partitions uses a LDA distribution based on concentration. + """ + print( + f">>> [Dataset] {num_clients} clients, non-iid concentration {concentration}..." 
# pylint: disable=too-many-locals
@hydra.main(config_path="conf", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run the baseline.

    Parameters
    ----------
    cfg : DictConfig
        An omegaconf object that stores the hydra config.
    """
    # Seed NumPy's global RNG so shuffling/partitioning is reproducible.
    np.random.seed(2020)

    # 1. Print parsed config
    print(OmegaConf.to_yaml(cfg))

    # 2. Prepare your dataset
    x_train, y_train, x_test, y_test, input_shape, num_classes = instantiate(
        cfg.dataset
    )

    partitions = partition(x_train, y_train, cfg.num_clients, cfg.noniid.concentration)

    print(f">>> [Model]: Num. Classes {num_classes} | Input shape: {input_shape}")

    # 3. Define your clients
    client_fn = generate_client_fn(partitions, cfg.model, num_classes)

    # 4. Define your strategy
    evaluate_fn = get_evaluate_fn(
        instantiate(cfg.model), x_test, y_test, cfg.num_rounds, num_classes
    )

    strategy = instantiate(cfg.strategy, evaluate_fn=evaluate_fn)

    # 5. Start Simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=cfg.num_clients,
        config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
        strategy=strategy,
        client_resources={"num_cpus": cfg.num_cpus, "num_gpus": cfg.num_gpus},
    )

    _, final_acc = history.metrics_centralized["accuracy"][-1]

    # 6. Save your results
    save_path = HydraConfig.get().runtime.output_dir

    strategy_name = strategy.__class__.__name__
    # Infer the dataset name from its input shape (see conf/dataset/*.yaml).
    dataset_type = "cifar10" if cfg.dataset.input_shape == [32, 32, 3] else "fmnist"

    def format_variable(x):
        # repr() bytes values so they serialize cleanly into the file name.
        return f"{x!r}" if isinstance(x, bytes) else x

    file_suffix: str = (
        f"_{format_variable(strategy_name)}"
        f"_{format_variable(dataset_type)}"
        f"_clients={format_variable(cfg.num_clients)}"
        f"_rounds={format_variable(cfg.num_rounds)}"
        f"_C={format_variable(cfg.server.reporting_fraction)}"
        f"_E={format_variable(cfg.client.local_epochs)}"
        f"_alpha={format_variable(cfg.noniid.concentration)}"
        f"_server-momentum={format_variable(cfg.server.momentum)}"
        f"_client-lr={format_variable(cfg.client.lr)}"
        f"_acc={format_variable(final_acc):.4f}"
    )

    filename = "results" + file_suffix + ".pkl"

    # Fixed: the previous message was an f-string that interpolated nothing;
    # report the actual output file name instead.
    print(f">>> Saving {filename}...")
    results_path = Path(save_path) / filename
    results = {"history": history}

    with open(str(results_path), "wb") as hist_file:
        pickle.dump(results, hist_file, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
def cnn(input_shape, num_classes, learning_rate):
    """CNN Model from (McMahan et. al., 2017).

    Communication-efficient learning of deep networks from decentralized data
    """
    input_shape = tuple(input_shape)
    weight_decay = 0.004

    # Two conv/pool/batch-norm stages followed by a fully-connected head.
    layers = [
        keras.layers.Conv2D(
            64,
            (5, 5),
            padding="same",
            activation="relu",
            input_shape=input_shape,
        ),
        keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(
            64,
            (5, 5),
            padding="same",
            activation="relu",
        ),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(
            384, activation="relu", kernel_regularizer=l2(weight_decay)
        ),
        keras.layers.Dense(
            192, activation="relu", kernel_regularizer=l2(weight_decay)
        ),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
    model = keras.Sequential(layers)

    model.compile(
        loss="categorical_crossentropy",
        optimizer=SGD(learning_rate=learning_rate),
        metrics=["accuracy"],
    )
    return model
def model_to_parameters(model):
    """Retrieve model weights and convert them to Flower ``Parameters``.

    Referenced from the Hydra strategy configs to provide
    ``initial_parameters`` for server-side strategies.
    """
    return ndarrays_to_parameters(model.get_weights())
+ """ + + def fit_config_fn(server_round: int): # pylint: disable=unused-argument + # option to use scheduling of learning rate based on round + # if server_round > 50: + # lr = config.lr / 10 + return { + "local_epochs": config.local_epochs, + "batch_size": config.batch_size, + } + + return fit_config_fn + + +def get_evaluate_fn(model, x_test, y_test, num_rounds, num_classes): + """Generate the function for server global model evaluation. + + The method evaluate_fn runs after global model aggregation. + """ + + def evaluate_fn( + server_round: int, parameters, config + ): # pylint: disable=unused-argument + if server_round == num_rounds: # evaluates global model just on the last round + # instantiate the model + model.set_weights(parameters) + + y_test_cat = to_categorical(y_test, num_classes=num_classes) + loss, accuracy = model.evaluate(x_test, y_test_cat, verbose=False) + + return loss, {"accuracy": accuracy} + + return None + + return evaluate_fn diff --git a/baselines/fedavgm/fedavgm/strategy.py b/baselines/fedavgm/fedavgm/strategy.py new file mode 100644 index 000000000000..cd0a27254fce --- /dev/null +++ b/baselines/fedavgm/fedavgm/strategy.py @@ -0,0 +1,201 @@ +"""Optionally define a custom strategy. + +Needed only when the strategy is not yet implemented in Flower or because you want to +extend or modify the functionality of an existing strategy. +""" + +from logging import WARNING +from typing import Callable, Dict, List, Optional, Tuple, Union + +from flwr.common import ( + FitRes, + MetricsAggregationFn, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import FedAvg +from flwr.server.strategy.aggregate import aggregate + + +class CustomFedAvgM(FedAvg): + """Re-implmentation of FedAvgM. 
+ + This implementation of FedAvgM diverges from original (Flwr v1.5.0) implementation. + Here, the re-implementation introduces the Nesterov Accelerated Gradient (NAG), + same as reported in the original FedAvgM paper: + + https://arxiv.org/pdf/1909.06335.pdf + """ + + def __init__( + self, + *, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Parameters, + fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + server_learning_rate: float = 1.0, + server_momentum: float = 0.9, + ) -> None: + """Federated Averaging with Momentum strategy. + + Implementation based on https://arxiv.org/pdf/1909.06335.pdf + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 0.1. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 0.1. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. 
+ on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + server_learning_rate: float + Server-side learning rate used in server-side optimization. + Defaults to 1.0. + server_momentum: float + Server-side momentum factor used for FedAvgM. Defaults to 0.9. + """ + super().__init__( + fraction_fit=fraction_fit, + fraction_evaluate=fraction_evaluate, + min_fit_clients=min_fit_clients, + min_evaluate_clients=min_evaluate_clients, + min_available_clients=min_available_clients, + evaluate_fn=evaluate_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + ) + self.server_learning_rate = server_learning_rate + self.server_momentum = server_momentum + self.momentum_vector: Optional[NDArrays] = None + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedAvgM(accept_failures={self.accept_failures})" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters.""" + return self.initial_parameters + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average.""" + if not results: + return None, {} + + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results + 
weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + + fedavg_result = aggregate(weights_results) # parameters_aggregated from FedAvg + + # original implementation follows convention described in + # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html + + # do the check for self.initial_parameters being set + assert ( + self.initial_parameters is not None + ), "Initial parameters must be set for CustomFedAvgM strategy" + + # remember that updates are the opposite of gradients + pseudo_gradient: NDArrays = [ + x - y + for x, y in zip( + parameters_to_ndarrays(self.initial_parameters), fedavg_result + ) + ] + + if server_round > 1: + assert self.momentum_vector, "Momentum should have been created on round 1." + + self.momentum_vector = [ + self.server_momentum * v + w + for w, v in zip(pseudo_gradient, self.momentum_vector) + ] + else: # Round 1 + # Initialize server-side model + assert ( + self.initial_parameters is not None + ), "When using server-side optimization, model needs to be initialized." 
+ # Initialize momentum vector + self.momentum_vector = pseudo_gradient + + # Applying Nesterov + pseudo_gradient = [ + g + self.server_momentum * v + for g, v in zip(pseudo_gradient, self.momentum_vector) + ] + + # Federated Averaging with Server Momentum + fedavgm_result = [ + w - self.server_learning_rate * v + for w, v in zip( + parameters_to_ndarrays(self.initial_parameters), pseudo_gradient + ) + ] + + # Update current weights + self.initial_parameters = ndarrays_to_parameters(fedavgm_result) + + parameters_aggregated = ndarrays_to_parameters(fedavgm_result) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return parameters_aggregated, metrics_aggregated diff --git a/baselines/fedavgm/fedavgm/utils.py b/baselines/fedavgm/fedavgm/utils.py new file mode 100644 index 000000000000..42a3f372e6ad --- /dev/null +++ b/baselines/fedavgm/fedavgm/utils.py @@ -0,0 +1,61 @@ +"""Define any utility function. + +They are not directly relevant to the other (more FL specific) python modules. For +example, you may define here things like: loading a model from a checkpoint, saving +results, plotting. 
+""" + +import matplotlib.pyplot as plt +import numpy as np + +from fedavgm.dataset import cifar10, partition + +# pylint: disable=too-many-locals + + +def plot_concentrations_cifar10(): + """Create a plot with different concentrations for dataset using LDA.""" + x_train, y_train, x_test, y_test, _, num_classes = cifar10(10, (32, 32, 3)) + x = np.concatenate((x_train, x_test), axis=0) + y = np.concatenate((y_train, y_test), axis=0) + num_clients = 30 + + # Simulated different concentrations for partitioning + concentration_values = [np.inf, 100, 1, 0.1, 0.01, 1e-10] + color = plt.get_cmap("RdYlGn")(np.linspace(0.15, 0.85, num_classes)) + num_plots = len(concentration_values) + fig, axs = plt.subplots(1, num_plots, figsize=(15, 5), sharey=True) + + pos = axs[0].get_position() + pos.x0 += 0.1 + axs[0].set_position(pos) + + for i, concentration in enumerate(concentration_values): + partitions = partition(x, y, num_clients, concentration) + + for client in range(num_clients): + _, y_client = partitions[client] + lefts = [0] + axis = axs[i] + class_counts = np.bincount(y_client, minlength=num_classes) + np.sum(class_counts > 0) + + class_distribution = class_counts.astype(np.float16) / len(y_client) + + for idx, val in enumerate(class_distribution[:-1]): + lefts.append(lefts[idx] + val) + + axis.barh(client, class_distribution, left=lefts, color=color) + axis.set_xticks([]) + axis.set_yticks([]) + axis.set_xlabel("Class distribution") + axis.set_title(f"Concentration = {concentration}") + + fig.text(0, 0.5, "Client", va="center", rotation="vertical") + plt.tight_layout() + plt.savefig("../_static/concentration_cifar10_v2.png") + print(">>> Concentration plot created") + + +if __name__ == "__main__": + plot_concentrations_cifar10() diff --git a/baselines/fedavgm/pyproject.toml b/baselines/fedavgm/pyproject.toml new file mode 100644 index 000000000000..d222baa65b0e --- /dev/null +++ b/baselines/fedavgm/pyproject.toml @@ -0,0 +1,139 @@ +[build-system] +requires = 
["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "fedavgm" +version = "1.0.0" +description = "FedAvgM: Measuring the effects of non-identical data distribution for federated visual classification" +license = "Apache-2.0" +authors = ["Gustavo Bertoli"] +readme = "README.md" +homepage = "https://flower.ai" +repository = "https://github.com/adap/flower" +documentation = "https://flower.ai" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.9, <3.12.0" # changed! 
original baseline template uses >= 3.8.15 +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +cython = "^3.0.0" +tensorflow = "2.11.1" +numpy = "1.25.2" +matplotlib = "^3.7.2" + +[tool.poetry.dev-dependencies] +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" +virtualenv = "==20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators = "hydra.main.main" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = 
"numpy" diff --git a/baselines/fedbn/README.md b/baselines/fedbn/README.md index 4b271bd49851..37388894ff9d 100644 --- a/baselines/fedbn/README.md +++ b/baselines/fedbn/README.md @@ -34,37 +34,37 @@ dataset: [MNIST, MNIST-M, SVHN, USPS, SynthDigits] **Model:** A six-layer CNN with 14,219,210 parameters following the structure described in appendix D.2. -**Dataset:** This baseline makes use of the pre-processed partitions created and open source by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. +**Dataset:** This baseline makes use of the pre-processed partitions created and open source by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. A more detailed explanation of the datasets is given in the following table. -| | MNIST | MNIST-M | SVHN | USPS | SynthDigits | -|--- |--- |--- |--- |--- |--- | -| data type| handwritten digits| MNIST modification randomly colored with colored patches| Street view house numbers | handwritten digits from envelopes by the U.S. 
Postal Service | Syntehtic digits Windows TM font varying the orientation, blur and stroke colors | -| color | greyscale | RGB | RGB | greyscale | RGB | -| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 | -| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 | -| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 | -| number of testset| 10.000 | 10.000 | 26.032 | - | - | -| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) | +| | MNIST | MNIST-M | SVHN | USPS | SynthDigits | +| ------------------ | ------------------ | -------------------------------------------------------- | ------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------------------------- | +| data type | handwritten digits | MNIST modification randomly colored with colored patches | Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service | Synthetic digits Windows TM font varying the orientation, blur and stroke colors | +| color | greyscale | RGB | RGB | greyscale | RGB | +| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 | +| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 | +| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 | +| number of testset | 10.000 | 10.000 | 26.032 | - | - | +| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) | **Training Hyperparameters:** By default (i.e. if you don't override anything in the config) these main hyperparameters used are shown in the table below. For a complete list of hyperparameters, please refer to the config files in `fedbn/conf`. 
-| Description | Value | -| ----------- | ----- | -| rounds | 10 | -| num_clients | 5 | -| strategy_fraction_fit | 1.0 | -| strategy.fraction_evaluate | 0.0 | -| training samples per client| 743 | -| client.l_r | 10E-2 | -| local epochs | 1 | -| loss | cross entropy loss | -| optimizer | SGD | -| client_resources.num_cpu | 2 | -| client_resources.num_gpus | 0.0 | +| Description | Value | +| --------------------------- | ------------------ | +| rounds | 10 | +| num_clients | 5 | +| strategy_fraction_fit | 1.0 | +| strategy.fraction_evaluate | 0.0 | +| training samples per client | 743 | +| client.l_r | 10E-2 | +| local epochs | 1 | +| loss | cross entropy loss | +| optimizer | SGD | +| client_resources.num_cpu | 2 | +| client_resources.num_gpus | 0.0 | ## Environment Setup @@ -93,7 +93,7 @@ cd data .. ## Running the Experiments -First, activate your environment via `poetry shell`. The commands below show how to run the experiments and modify some of its key hyperparameters via the cli. Each time you run an experiment, the log and results will be stored inside `outputs// {% endif %} diff --git a/doc/source/_templates/sidebar/versioning.html b/doc/source/_templates/sidebar/versioning.html index de607b55373a..74f1cd8febb7 100644 --- a/doc/source/_templates/sidebar/versioning.html +++ b/doc/source/_templates/sidebar/versioning.html @@ -57,14 +57,40 @@ } -
- + + +
+
diff --git a/doc/source/conf.py b/doc/source/conf.py index 8077d26aa6ae..88cb5c05b1d8 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -14,6 +14,7 @@ # ============================================================================== +import datetime import os import sys from git import Repo @@ -81,11 +82,11 @@ # -- Project information ----------------------------------------------------- project = "Flower" -copyright = "2022 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags -release = "1.7.0" +release = "1.8.0" # -- General configuration --------------------------------------------------- @@ -95,6 +96,7 @@ extensions = [ "sphinx.ext.napoleon", "sphinx.ext.autodoc", + "sphinx.ext.autosummary", "sphinx.ext.mathjax", "sphinx.ext.viewcode", "sphinx.ext.graphviz", @@ -108,6 +110,44 @@ "nbsphinx", ] +# Generate .rst files +autosummary_generate = True + +# Document ONLY the objects from __all__ (present in __init__ files). +# It will be done recursively starting from flwr.__init__ +# Starting point is controlled in the index.rst file. 
+autosummary_ignore_module_all = False + +# Each class and function docs start with the path to it +# Make the flwr_datasets.federated_dataset.FederatedDataset appear as FederatedDataset +# The full name is still at the top of the page +add_module_names = False + + +def find_test_modules(package_path): + """Go through the python files and exclude every *_test.py file.""" + full_path_modules = [] + for root, dirs, files in os.walk(package_path): + for file in files: + if file.endswith("_test.py"): + # Construct the module path relative to the package directory + full_path = os.path.join(root, file) + relative_path = os.path.relpath(full_path, package_path) + # Convert file path to dotted module path + module_path = os.path.splitext(relative_path)[0].replace(os.sep, ".") + full_path_modules.append(module_path) + modules = [] + for full_path_module in full_path_modules: + parts = full_path_module.split(".") + for i in range(len(parts)): + modules.append(".".join(parts[i:])) + return modules + + +# Stop from documenting the *_test.py files. +# That's the only way to do that in autosummary (make the modules as mock_imports). +autodoc_mock_imports = find_test_modules(os.path.abspath("../../src/py/flwr")) + # Add any paths that contain templates here, relative to this directory. 
templates_path = ["_templates"] @@ -135,6 +175,7 @@ "writing-documentation": "contributor-how-to-write-documentation.html", "apiref-binaries": "ref-api-cli.html", "fedbn-example-pytorch-from-centralized-to-federated": "example-fedbn-pytorch-from-centralized-to-federated.html", + "how-to-use-built-in-middleware-layers": "how-to-use-built-in-mods.html", # Restructuring: tutorials "tutorial/Flower-0-What-is-FL": "tutorial-series-what-is-federated-learning.html", "tutorial/Flower-1-Intro-to-FL-PyTorch": "tutorial-series-get-started-with-flower-pytorch.html", @@ -173,7 +214,8 @@ "evaluation": "explanation-federated-evaluation.html", "differential-privacy-wrappers": "explanation-differential-privacy.html", # Restructuring: references - "apiref-flwr": "ref-api-flwr.html", + "apiref-flwr": "ref-api/flwr.html", + "ref-api-flwr": "ref-api/flwr.html", "apiref-cli": "ref-api-cli.html", "examples": "ref-example-projects.html", "telemetry": "ref-telemetry.html", @@ -209,7 +251,7 @@ html_title = f"Flower Framework" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/framework/" +html_baseurl = "https://flower.ai/docs/framework/" html_theme_options = { # diff --git a/doc/source/contributor-explanation-architecture.rst b/doc/source/contributor-explanation-architecture.rst index 0e2ea1f6e66b..a20a84313118 100644 --- a/doc/source/contributor-explanation-architecture.rst +++ b/doc/source/contributor-explanation-architecture.rst @@ -4,7 +4,7 @@ Flower Architecture Edge Client Engine ------------------ -`Flower `_ core framework architecture with Edge Client Engine +`Flower `_ core framework architecture with Edge Client Engine .. figure:: _static/flower-architecture-ECE.png :width: 80 % @@ -12,7 +12,7 @@ Edge Client Engine Virtual Client Engine --------------------- -`Flower `_ core framework architecture with Virtual Client Engine +`Flower `_ core framework architecture with Virtual Client Engine .. 
figure:: _static/flower-architecture-VCE.png :width: 80 % @@ -20,7 +20,7 @@ Virtual Client Engine Virtual Client Engine and Edge Client Engine in the same workload ----------------------------------------------------------------- -`Flower `_ core framework architecture with both Virtual Client Engine and Edge Client Engine +`Flower `_ core framework architecture with both Virtual Client Engine and Edge Client Engine .. figure:: _static/flower-architecture.drawio.png :width: 80 % diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst new file mode 100644 index 000000000000..5dead265bee2 --- /dev/null +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -0,0 +1,135 @@ +How to build Docker Flower images locally +========================================= + +Flower provides pre-made docker images on `Docker Hub `_ +that include all necessary dependencies for running the server. You can also build your own custom +docker images from scratch with a different version of Python or Ubuntu if that is what you need. +In this guide, we will explain what images exist and how to build them locally. + +Before we can start, we need to meet a few prerequisites in our local development environment. + +#. Clone the flower repository. + + .. code-block:: bash + + $ git clone https://github.com/adap/flower.git && cd flower + +#. Verify the Docker daemon is running. + + Please follow the first section on + :doc:`Run Flower using Docker ` + which covers this step in more detail. + +Currently, Flower provides two images, a base image and a server image. There will also be a client +image soon. The base image, as the name suggests, contains basic dependencies that both the server +and the client need. This includes system dependencies, Python and Python tools. The server image is +based on the base image, but it additionally installs the Flower server using ``pip``. 
+ +The build instructions that assemble the images are located in the respective Dockerfiles. You +can find them in the subdirectories of ``src/docker``. + +Both the base and the server image are configured via build arguments. Through build arguments, we can make +our build more flexible. For example, in the base image, we can specify the version of Python to +install using the ``PYTHON_VERSION`` build argument. Some of the build arguments have default +values, others must be specified when building the image. All available build arguments for each +image are listed in one of the tables below. + +Building the base image +----------------------- + +.. list-table:: + :widths: 25 45 15 15 + :header-rows: 1 + + * - Build argument + - Description + - Required + - Example + * - ``PYTHON_VERSION`` + - Version of ``python`` to be installed. + - Yes + - ``3.11`` + * - ``PIP_VERSION`` + - Version of ``pip`` to be installed. + - Yes + - ``23.0.1`` + * - ``SETUPTOOLS_VERSION`` + - Version of ``setuptools`` to be installed. + - Yes + - ``69.0.2`` + * - ``UBUNTU_VERSION`` + - Version of the official Ubuntu Docker image. + - Defaults to ``22.04``. + - + +The following example creates a base image with Python 3.11.0, pip 23.0.1 and setuptools 69.0.2: + +.. code-block:: bash + + $ cd src/docker/base/ + $ docker build \ + --build-arg PYTHON_VERSION=3.11.0 \ + --build-arg PIP_VERSION=23.0.1 \ + --build-arg SETUPTOOLS_VERSION=69.0.2 \ + -t flwr_base:0.1.0 . + +The name of the image is ``flwr_base`` and the tag ``0.1.0``. Remember that the build arguments as well +as the name and tag can be adapted to your needs. These values serve as examples only. + +Building the server image +------------------------- + +.. list-table:: + :widths: 25 45 15 15 + :header-rows: 1 + + * - Build argument + - Description + - Required + - Example + * - ``BASE_REPOSITORY`` + - The repository name of the base image. + - Defaults to ``flwr/server``. + - + * - ``BASE_IMAGE_TAG`` + - The image tag of the base image. 
+ - Defaults to ``py3.11-ubuntu22.04``. + - + * - ``FLWR_VERSION`` + - Version of Flower to be installed. + - Yes + - ``1.7.0`` + +The following example creates a server image with the official Flower base image py3.11-ubuntu22.04 +and Flower 1.7.0: + +.. code-block:: bash + + $ cd src/docker/server/ + $ docker build \ + --build-arg BASE_IMAGE_TAG=py3.11-ubuntu22.04 \ + --build-arg FLWR_VERSION=1.7.0 \ + -t flwr_server:0.1.0 . + +The name of the image is ``flwr_server`` and the tag ``0.1.0``. Remember that the build arguments as well +as the name and tag can be adapted to your needs. These values serve as examples only. + +If you want to use your own base image instead of the official Flower base image, all you need to do +is set the ``BASE_REPOSITORY`` and ``BASE_IMAGE_TAG`` build arguments. The value of +``BASE_REPOSITORY`` must match the name of your image and the value of ``BASE_IMAGE_TAG`` must match +the tag of your image. + +.. code-block:: bash + + $ cd src/docker/server/ + $ docker build \ + --build-arg BASE_REPOSITORY=flwr_base \ + --build-arg BASE_IMAGE_TAG=0.1.0 \ + --build-arg FLWR_VERSION=1.7.0 \ + -t flwr_server:0.1.0 . + +After creating the image, we can test whether the image is working: + +.. code-block:: bash + + $ docker run --rm flwr_server:0.1.0 --help diff --git a/doc/source/contributor-how-to-contribute-translations.rst b/doc/source/contributor-how-to-contribute-translations.rst index d97a2cb8c64f..ba59901cf1c4 100644 --- a/doc/source/contributor-how-to-contribute-translations.rst +++ b/doc/source/contributor-how-to-contribute-translations.rst @@ -2,13 +2,13 @@ Contribute translations ======================= Since `Flower 1.5 -`_ we +`_ we have introduced translations to our doc pages, but, as you might have noticed, the translations are often imperfect. If you speak languages other than English, you might be able to help us in our effort to make Federated Learning accessible to as many people as possible by contributing to those translations! 
This might also be a great opportunity for those wanting to become open source -contributors with little prerequistes. +contributors with little prerequisites. Our translation project is publicly available over on `Weblate `_, this where most @@ -44,7 +44,7 @@ This is what the interface looks like: .. image:: _static/weblate_interface.png -You input your translation in the textbox at the top and then, once you are +You input your translation in the text box at the top and then, once you are happy with it, you either press ``Save and continue`` (to save the translation and go to the next untranslated string), ``Save and stay`` (to save the translation and stay on the same page), ``Suggest`` (to add your translation to @@ -67,5 +67,5 @@ Add new languages ----------------- If you want to add a new language, you will first have to contact us, either on -`Slack `_, or by opening an issue on our `GitHub +`Slack `_, or by opening an issue on our `GitHub repo `_. diff --git a/doc/source/contributor-how-create-new-messages.rst b/doc/source/contributor-how-to-create-new-messages.rst similarity index 95% rename from doc/source/contributor-how-create-new-messages.rst rename to doc/source/contributor-how-to-create-new-messages.rst index 24fa5f573158..3f1849bdce47 100644 --- a/doc/source/contributor-how-create-new-messages.rst +++ b/doc/source/contributor-how-to-create-new-messages.rst @@ -29,8 +29,8 @@ Let's now see what we need to implement in order to get this simple function bet Message Types for Protocol Buffers ---------------------------------- -The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. -Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. +The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. 
+Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. Within the :code:`ServerMessage` block: @@ -75,7 +75,7 @@ Once that is done, we will compile the file with: $ python -m flwr_tool.protoc -If it compiles succesfully, you should see the following message: +If it compiles successfully, you should see the following message: .. code-block:: shell diff --git a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst index 19d46c5753c6..c861457b6edc 100644 --- a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst +++ b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst @@ -8,7 +8,7 @@ When working on the Flower framework we want to ensure that all contributors use Workspace files are mounted from the local file system or copied or cloned into the container. Extensions are installed and run inside the container, where they have full access to the tools, platform, and file system. This means that you can seamlessly switch your entire development environment just by connecting to a different container. -Source: `Official VSCode documentation `_ +Source: `Official VSCode documentation `_ Getting started @@ -20,5 +20,5 @@ Now you should be good to go. When starting VSCode, it will ask you to run in th In some cases your setup might be more involved. 
For those cases consult the following sources: -* `Developing inside a Container `_ -* `Remote development in Containers `_ +* `Developing inside a Container `_ +* `Remote development in Containers `_ diff --git a/doc/source/contributor-how-to-install-development-versions.rst b/doc/source/contributor-how-to-install-development-versions.rst index 243f4ef97e8e..558ec7f8ec46 100644 --- a/doc/source/contributor-how-to-install-development-versions.rst +++ b/doc/source/contributor-how-to-install-development-versions.rst @@ -19,8 +19,8 @@ Install ``flwr`` from a local copy of the Flower source code via ``pyproject.tom Install ``flwr`` from a local wheel file via ``pyproject.toml``: -- ``flwr = { path = "../../dist/flwr-1.0.0-py3-none-any.whl" }`` (without extras) -- ``flwr = { path = "../../dist/flwr-1.0.0-py3-none-any.whl", extras = ["simulation"] }`` (with extras) +- ``flwr = { path = "../../dist/flwr-1.8.0-py3-none-any.whl" }`` (without extras) +- ``flwr = { path = "../../dist/flwr-1.8.0-py3-none-any.whl", extras = ["simulation"] }`` (with extras) Please refer to the Poetry documentation for further details: `Poetry Dependency Specification `_ @@ -59,5 +59,5 @@ Open a development version of the same notebook from branch `branch-name` by cha Install a `whl` on Google Colab: 1. In the vertical icon grid on the left hand side, select ``Files`` > ``Upload to session storage`` -2. Upload the whl (e.g., ``flwr-1.7.0-py3-none-any.whl``) -3. Change ``!pip install -q 'flwr[simulation]' torch torchvision matplotlib`` to ``!pip install -q 'flwr-1.7.0-py3-none-any.whl[simulation]' torch torchvision matplotlib`` +2. Upload the whl (e.g., ``flwr-1.8.0-py3-none-any.whl``) +3. 
Change ``!pip install -q 'flwr[simulation]' torch torchvision matplotlib`` to ``!pip install -q 'flwr-1.8.0-py3-none-any.whl[simulation]' torch torchvision matplotlib`` diff --git a/doc/source/contributor-how-to-release-flower.rst b/doc/source/contributor-how-to-release-flower.rst index 2eef165c0ed0..4853d87bc4c1 100644 --- a/doc/source/contributor-how-to-release-flower.rst +++ b/doc/source/contributor-how-to-release-flower.rst @@ -3,23 +3,15 @@ Release Flower This document describes the current release process. It may or may not change in the future. -Before the release ------------------- - -Update the changelog (``changelog.md``) with all relevant changes that happened after the last release. If the last release was tagged ``v1.2.0``, you can use the following URL to see all commits that got merged into ``main`` since then: - -`GitHub: Compare v1.2.0...main `_ - -Thank the authors who contributed since the last release. This can be done by running the ``./dev/add-shortlog.sh `` convenience script (it can be ran multiple times and will update the names in the list if new contributors were added in the meantime). - During the release ------------------ The version number of a release is stated in ``pyproject.toml``. To release a new version of Flower, the following things need to happen (in that order): -1. Update the ``changelog.md`` section header ``Unreleased`` to contain the version number and date for the release you are building. Create a pull request with the change. -2. Tag the release commit with the version number as soon as the PR is merged: ``git tag v0.12.3``, then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. -3. Check the draft release on GitHub, and if everything is good, publish it. +1. 
Run ``python3 src/py/flwr_tool/update_changelog.py `` in order to add every new change to the changelog (feel free to make manual changes to the changelog afterwards until it looks good). +2. Once the changelog has been updated with all the changes, run ``./dev/prepare-release-changelog.sh v``, where ```` is the version stated in ``pyproject.toml`` (notice the ``v`` added before it). This will replace the ``Unreleased`` header of the changelog by the version and current date, and it will add a thanking message for the contributors. Open a pull request with those changes. +3. Once the pull request is merged, tag the release commit with the version number as soon as the PR is merged: ``git tag v`` (notice the ``v`` added before the version number), then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. +4. Check the draft release on GitHub, and if everything is good, publish it. After the release ----------------- @@ -30,7 +22,7 @@ Create a pull request which contains the following changes: 2. Update all files which contain the current version number if necessary. 3. Add a new ``Unreleased`` section in ``changelog.md``. -Merge the pull request on the same day (i.e., before a new nighly release gets published to PyPI). +Merge the pull request on the same day (i.e., before a new nightly release gets published to PyPI). Publishing a pre-release ------------------------ @@ -38,11 +30,11 @@ Publishing a pre-release Pre-release naming ~~~~~~~~~~~~~~~~~~ -PyPI supports pre-releases (alpha, beta, release candiate). Pre-releases MUST use one of the following naming patterns: +PyPI supports pre-releases (alpha, beta, release candidate). 
Pre-releases MUST use one of the following naming patterns: - Alpha: ``MAJOR.MINOR.PATCHaN`` - Beta: ``MAJOR.MINOR.PATCHbN`` -- Release candiate (RC): ``MAJOR.MINOR.PATCHrcN`` +- Release candidate (RC): ``MAJOR.MINOR.PATCHrcN`` Examples include: diff --git a/doc/source/contributor-ref-good-first-contributions.rst b/doc/source/contributor-ref-good-first-contributions.rst index 523a4679c6ef..2b8ce88413f5 100644 --- a/doc/source/contributor-ref-good-first-contributions.rst +++ b/doc/source/contributor-ref-good-first-contributions.rst @@ -14,7 +14,7 @@ Until the Flower core library matures it will be easier to get PR's accepted if they only touch non-core areas of the codebase. Good candidates to get started are: -- Documentation: What's missing? What could be expressed more clearly? +- Documentation: What's missing? What could be expressed more clearly? - Baselines: See below. - Examples: See below. @@ -22,11 +22,11 @@ are: Request for Flower Baselines ---------------------------- -If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. +If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. -You should then check out the open +You should then check out the open `issues `_ for baseline requests. -If you find a baseline that you'd like to work on and that has no assignes, feel free to assign it to yourself and start working on it! +If you find a baseline that you'd like to work on and that has no assignees, feel free to assign it to yourself and start working on it! Otherwise, if you don't find a baseline you'd like to work on, be sure to open a new issue with the baseline request template! 
diff --git a/doc/source/contributor-tutorial-contribute-on-github.rst b/doc/source/contributor-tutorial-contribute-on-github.rst index 9aeb8229b412..6da81ce73662 100644 --- a/doc/source/contributor-tutorial-contribute-on-github.rst +++ b/doc/source/contributor-tutorial-contribute-on-github.rst @@ -3,9 +3,7 @@ Contribute on GitHub This guide is for people who want to get involved with Flower, but who are not used to contributing to GitHub projects. -If you're familiar with how contributing on GitHub works, you can directly checkout our -`getting started guide for contributors `_ -and examples of `good first contributions `_. +If you're familiar with how contributing on GitHub works, you can directly checkout our :doc:`getting started guide for contributors `. Setting up the repository @@ -13,21 +11,21 @@ Setting up the repository 1. **Create a GitHub account and setup Git** Git is a distributed version control tool. This allows for an entire codebase's history to be stored and every developer's machine. - It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. + It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. GitHub, itself, is a code hosting platform for version control and collaboration. It allows for everyone to collaborate and work from anywhere on remote repositories. - If you haven't already, you will need to create an account on `GitHub `_. + If you haven't already, you will need to create an account on `GitHub `_. - The idea behind the generic Git and GitHub workflow boils down to this: + The idea behind the generic Git and GitHub workflow boils down to this: you download code from a remote repository on GitHub, make changes locally and keep track of them using Git and then you upload your new history back to GitHub. 2. **Forking the Flower repository** - A fork is a personal copy of a GitHub repository. 
To create one for Flower, you must navigate to https://github.com/adap/flower (while connected to your GitHub account) + A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to ``_ (while connected to your GitHub account) and click the ``Fork`` button situated on the top right of the page. .. image:: _static/fork_button.png - + You can change the name if you want, but this is not necessary as this version of Flower will be yours and will sit inside your own account (i.e., in your own list of repositories). Once created, you should see on the top left corner that you are looking at your own version of Flower. @@ -35,18 +33,18 @@ Setting up the repository 3. **Cloning your forked repository** The next step is to download the forked repository on your machine to be able to make changes to it. - On your forked repository page, you should first click on the ``Code`` button on the right, + On your forked repository page, you should first click on the ``Code`` button on the right, this will give you the ability to copy the HTTPS link of the repository. .. image:: _static/cloning_fork.png Once you copied the \, you can open a terminal on your machine, navigate to the place you want to download the repository to and type: - .. code-block:: shell + .. code-block:: shell $ git clone - This will create a `flower/` (or the name of your fork if you renamed it) folder in the current working directory. + This will create a ``flower/`` (or the name of your fork if you renamed it) folder in the current working directory. 4. **Add origin** You can then go into the repository folder: @@ -59,17 +57,17 @@ Setting up the repository To obtain it, we can do as previously mentioned by going to our fork repository on our GitHub account and copying the link. .. image:: _static/cloning_fork.png - + Once the \ is copied, we can type the following command in our terminal: .. code-block:: shell $ git remote add origin - + 5. 
**Add upstream** Now we will add an upstream address to our repository. - Still in the same directroy, we must run the following command: + Still in the same directory, we must run the following command: .. code-block:: shell @@ -77,10 +75,10 @@ Setting up the repository The following diagram visually explains what we did in the previous steps: - .. image:: _static/github_schema.png + .. image:: _static/github_schema.png - The upstream is the GitHub remote address of the parent repository (in this case Flower), - i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. + The upstream is the GitHub remote address of the parent repository (in this case Flower), + i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. The origin is just the GitHub remote address of the forked repository we created, i.e. the copy (fork) in our own account. To make sure our local version of the fork is up-to-date with the latest changes from the Flower repository, @@ -94,7 +92,7 @@ Setting up the repository Setting up the coding environment --------------------------------- -This can be achieved by following this `getting started guide for contributors`_ (note that you won't need to clone the repository). +This can be achieved by following this :doc:`getting started guide for contributors ` (note that you won't need to clone the repository). Once you are able to write code and test it, you can finally start making changes! @@ -114,9 +112,9 @@ And with Flower's repository: $ git pull upstream main 1. **Create a new branch** - To make the history cleaner and easier to work with, it is good practice to + To make the history cleaner and easier to work with, it is good practice to create a new branch for each feature/project that needs to be implemented. - + To do so, just run the following command inside the repository's directory: .. 
code-block:: shell @@ -138,7 +136,7 @@ And with Flower's repository: $ ./dev/test.sh # to test that your code can be accepted $ ./baselines/dev/format.sh # same as above but for code added to baselines $ ./baselines/dev/test.sh # same as above but for code added to baselines - + 4. **Stage changes** Before creating a commit that will update your history, you must specify to Git which files it needs to take into account. @@ -180,22 +178,26 @@ Creating and merging a pull request (PR) .. image:: _static/compare_and_pr.png - Otherwise you can always find this option in the `Branches` page. + Otherwise you can always find this option in the ``Branches`` page. - Once you click the `Compare & pull request` button, you should see something similar to this: + Once you click the ``Compare & pull request`` button, you should see something similar to this: .. image:: _static/creating_pr.png - + At the top you have an explanation of which branch will be merged where: .. image:: _static/merging_branch.png - + In this example you can see that the request is to merge the branch ``doc-fixes`` from my forked repository to branch ``main`` from the Flower repository. - The input box in the middle is there for you to describe what your PR does and to link it to existing issues. + The input box in the middle is there for you to describe what your PR does and to link it to existing issues. We have placed comments (that won't be rendered once the PR is opened) to guide you through the process. - At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and + It is important to follow the instructions described in comments. For instance, in order to not break how our changelog system works, + you should read the information above the ``Changelog entry`` section carefully. + You can also checkout some examples and details in the :ref:`changelogentry` appendix. + + At the bottom you will find the button to open the PR. 
This will notify reviewers that a new PR has been opened and that they should look over it to merge or to request changes. If your PR is not yet ready for review, and you don't want to notify anyone, you have the option to create a draft pull request: @@ -215,7 +217,7 @@ Creating and merging a pull request (PR) Merging will be blocked if there are ongoing requested changes. .. image:: _static/changes_requested.png - + To resolve them, just push the necessary changes to the branch associated with the PR: .. image:: _static/make_changes.png @@ -253,54 +255,54 @@ Example of first contribution Problem ******* -For our documentation, we’ve started to use the `Diàtaxis framework `_. +For our documentation, we've started to use the `Diàtaxis framework `_. -Our “How to” guides should have titles that continue the sencence “How to …”, for example, “How to upgrade to Flower 1.0”. +Our "How to" guides should have titles that continue the sentence "How to …", for example, "How to upgrade to Flower 1.0". Most of our guides do not follow this new format yet, and changing their title is (unfortunately) more involved than one might think. -This issue is about changing the title of a doc from present continious to present simple. +This issue is about changing the title of a doc from present continuous to present simple. -Let's take the example of “Saving Progress” which we changed to “Save Progress”. Does this pass our check? +Let's take the example of "Saving Progress" which we changed to "Save Progress". Does this pass our check? -Before: ”How to saving progress” ❌ +Before: "How to saving progress" ❌ -After: ”How to save progress” ✅ +After: "How to save progress" ✅ Solution ******** -This is a tiny change, but it’ll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here’s what you should do: +This is a tiny change, but it'll allow us to test your end-to-end setup. 
After cloning and setting up the Flower repo, here's what you should do: -- Find the source file in `doc/source` -- Make the change in the `.rst` file (beware, the dashes under the title should be the same length as the title itself) -- Build the docs and check the result: ``_ +- Find the source file in ``doc/source`` +- Make the change in the ``.rst`` file (beware, the dashes under the title should be the same length as the title itself) +- Build the docs and `check the result `_ Rename file ::::::::::: -You might have noticed that the file name still reflects the old wording. +You might have noticed that the file name still reflects the old wording. If we just change the file, then we break all existing links to it - it is **very important** to avoid that, breaking links can harm our search engine ranking. -Here’s how to change the file name: +Here's how to change the file name: -- Change the file name to `save-progress.rst` -- Add a redirect rule to `doc/source/conf.py` +- Change the file name to ``save-progress.rst`` +- Add a redirect rule to ``doc/source/conf.py`` -This will cause a redirect from `saving-progress.html` to `save-progress.html`, old links will continue to work. +This will cause a redirect from ``saving-progress.html`` to ``save-progress.html``, old links will continue to work. Apply changes in the index file ::::::::::::::::::::::::::::::: -For the lateral navigation bar to work properly, it is very important to update the `index.rst` file as well. +For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well. This is where we define the whole arborescence of the navbar. 
-- Find and modify the file name in `index.rst` +- Find and modify the file name in ``index.rst`` Open PR ::::::: -- Commit the changes (commit messages are always imperative: “Do something”, in this case “Change …”) +- Commit the changes (commit messages are always imperative: "Do something", in this case "Change …") - Push the changes to your fork - Open a PR (as shown above) - Wait for it to be approved! @@ -331,7 +333,7 @@ Here are a few positive examples which provide helpful information without repea * Update docs banner to mention Flower Summit 2023 * Remove unnecessary XGBoost dependency * Remove redundant attributes in strategies subclassing FedAvg -* Add CI job to deploy the staging system when the `main` branch changes +* Add CI job to deploy the staging system when the ``main`` branch changes * Add new amazing library which will be used to improve the simulation engine @@ -340,4 +342,77 @@ Next steps Once you have made your first PR, and want to contribute more, be sure to check out the following : -- `Good first contributions `_, where you should particularly look into the :code:`baselines` contributions. +- :doc:`Good first contributions `, where you should particularly look into the :code:`baselines` contributions. + + +Appendix +-------- + +.. _changelogentry: + +Changelog entry +*************** + +When opening a new PR, inside its description, there should be a ``Changelog entry`` header. + +Above this header you should see the following comment that explains how to write your changelog entry: + + Inside the following 'Changelog entry' section, + you should put the description of your changes that will be added to the changelog alongside your PR title. + + If the section is completely empty (without any token) or non-existent, + the changelog will just contain the title of the PR for the changelog entry, without any description. + + If the section contains some text other than tokens, it will use it to add a description to the change. 
+ + If the section contains one of the following tokens it will ignore any other text and put the PR under the corresponding section of the changelog: + + is for classifying a PR as a general improvement. + + is to not add the PR to the changelog + + is to add a general baselines change to the PR + + is to add a general examples change to the PR + + is to add a general sdk change to the PR + + is to add a general simulations change to the PR + + Note that only one token should be used. + +Its content must have a specific format. We will break down what each possibility does: + +- If the ``### Changelog entry`` section contains nothing or doesn't exist, the following text will be added to the changelog:: + + - **PR TITLE** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains a description (and no token), the following text will be added to the changelog:: + + - **PR TITLE** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + + DESCRIPTION FROM THE CHANGELOG ENTRY + +- If the ``### Changelog entry`` section contains ````, nothing will change in the changelog. 
+ +- If the ``### Changelog entry`` section contains ````, the following text will be added to the changelog:: + + - **General improvements** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ````, the following text will be added to the changelog:: + + - **General updates to Flower Baselines** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ````, the following text will be added to the changelog:: + + - **General updates to Flower Examples** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ````, the following text will be added to the changelog:: + + - **General updates to Flower SDKs** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ````, the following text will be added to the changelog:: + + - **General updates to Flower Simulations** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +Note that only one token must be provided, otherwise, only the first action (in the order listed above), will be performed. diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index 72c6df5fdbc7..9136fea96bf6 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -11,7 +11,7 @@ Prerequisites Flower uses :code:`pyproject.toml` to manage dependencies and configure development tools (the ones which support it). Poetry is a build tool which -supports `PEP 517 `_. +supports `PEP 517 `_. Developer Machine Setup @@ -27,7 +27,7 @@ For macOS * Install `homebrew `_. Don't forget the post-installation actions to add `brew` to your PATH. 
* Install `xz` (to install different Python versions) and `pandoc` to build the docs:: - + $ brew install xz pandoc For Ubuntu @@ -54,7 +54,7 @@ GitHub:: * If you don't have :code:`pyenv` installed, the following script that will install it, set it up, and create the virtual environment (with :code:`Python 3.8.17` by default):: $ ./dev/setup-defaults.sh # once completed, run the bootstrap script - + * If you already have :code:`pyenv` installed (along with the :code:`pyenv-virtualenv` plugin), you can use the following convenience script (with :code:`Python 3.8.17` by default):: $ ./dev/venv-create.sh # once completed, run the `bootstrap.sh` script @@ -70,7 +70,7 @@ Convenience Scripts The Flower repository contains a number of convenience scripts to make recurring development tasks easier and less error-prone. See the :code:`/dev` -subdirectory for a full list. The following scripts are amonst the most +subdirectory for a full list. The following scripts are amongst the most important ones: Create/Delete Virtual Environment diff --git a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst index 5ebaa337dde8..0139f3b8dc31 100644 --- a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst +++ b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst @@ -3,11 +3,11 @@ Example: FedBN in PyTorch - From Centralized To Federated This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with `FedBN `_, a federated training strategy designed for non-iid data. We are using PyTorch to train a Convolutional Neural Network(with Batch Normalization layers) on the CIFAR-10 dataset. -When applying FedBN, only few changes needed compared to `Example: PyTorch - From Centralized To Federated `_. +When applying FedBN, only few changes needed compared to :doc:`Example: PyTorch - From Centralized To Federated `. 
Centralized Training -------------------- -All files are revised based on `Example: PyTorch - From Centralized To Federated `_. +All files are revised based on :doc:`Example: PyTorch - From Centralized To Federated `. The only thing to do is modifying the file called :code:`cifar.py`, revised part is shown below: The model architecture defined in class Net() is added with Batch Normalization layers accordingly. @@ -45,13 +45,13 @@ You can now run your machine learning workload: python3 cifar.py So far this should all look fairly familiar if you've used PyTorch before. -Let's take the next step and use what we've built to create a federated learning system within FedBN, the sytstem consists of one server and two clients. +Let's take the next step and use what we've built to create a federated learning system within FedBN, the system consists of one server and two clients. Federated Training ------------------ -If you have read `Example: PyTorch - From Centralized To Federated `_, the following parts are easy to follow, onyl :code:`get_parameters` and :code:`set_parameters` function in :code:`client.py` needed to revise. -If not, please read the `Example: PyTorch - From Centralized To Federated `_. first. +If you have read :doc:`Example: PyTorch - From Centralized To Federated `, the following parts are easy to follow, only the :code:`get_parameters` and :code:`set_parameters` functions in :code:`client.py` need to be revised. +If not, please read the :doc:`Example: PyTorch - From Centralized To Federated ` first. Our example consists of one *server* and two *clients*. In FedBN, :code:`server.py` keeps unchanged, we can start the server directly. @@ -66,7 +66,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an class CifarClient(fl.client.NumPyClient): """Flower client implementing CIFAR-10 image classification using PyTorch.""" - + ... 
def get_parameters(self, config) -> List[np.ndarray]: @@ -79,7 +79,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an params_dict = zip(keys, parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) self.model.load_state_dict(state_dict, strict=False) - + ... Now, you can now open two additional terminal windows and run diff --git a/doc/source/example-jax-from-centralized-to-federated.rst b/doc/source/example-jax-from-centralized-to-federated.rst index 2b1823e9d408..6b06a288a67a 100644 --- a/doc/source/example-jax-from-centralized-to-federated.rst +++ b/doc/source/example-jax-from-centralized-to-federated.rst @@ -259,7 +259,7 @@ Having defined the federation process, we can run it. # Start Flower client client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y) - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client) + fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) if __name__ == "__main__": main() diff --git a/doc/source/example-mxnet-walk-through.rst b/doc/source/example-mxnet-walk-through.rst index a395a8a723cb..c215f709ffb2 100644 --- a/doc/source/example-mxnet-walk-through.rst +++ b/doc/source/example-mxnet-walk-through.rst @@ -234,7 +234,7 @@ Our *client* needs to import :code:`flwr`, but also :code:`mxnet` to update the Implementing a Flower *client* basically means implementing a subclass of either :code:`flwr.client.Client` or :code:`flwr.client.NumPyClient`. Our implementation will be based on :code:`flwr.client.NumPyClient` and we'll call it :code:`MNISTClient`. -:code:`NumPyClient` is slighly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary. 
+:code:`NumPyClient` is slightly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary. :code:`MNISTClient` needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model: #. :code:`set_parameters (optional)` diff --git a/doc/source/example-pytorch-from-centralized-to-federated.rst b/doc/source/example-pytorch-from-centralized-to-federated.rst index d649658667da..0c458a136a81 100644 --- a/doc/source/example-pytorch-from-centralized-to-federated.rst +++ b/doc/source/example-pytorch-from-centralized-to-federated.rst @@ -174,7 +174,7 @@ However, with Flower you can evolve your pre-existing code into a federated lear The concept is easy to understand. We have to start a *server* and then use the code in :code:`cifar.py` for the *clients* that are connected to the *server*. -The *server* sends model parameters to the clients. The *clients* run the training and update the paramters. +The *server* sends model parameters to the clients. The *clients* run the training and update the parameters. The updated parameters are sent back to the *server* which averages all received parameter updates. This describes one round of the federated learning process and we repeat this for multiple rounds. @@ -195,7 +195,7 @@ We can already start the *server*: python3 server.py Finally, we will define our *client* logic in :code:`client.py` and build upon the previously defined centralized training in :code:`cifar.py`. -Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the paramters on our PyTorch model: +Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the parameters on our PyTorch model: .. 
code-block:: python @@ -212,7 +212,7 @@ Our *client* needs to import :code:`flwr`, but also :code:`torch` to update the Implementing a Flower *client* basically means implementing a subclass of either :code:`flwr.client.Client` or :code:`flwr.client.NumPyClient`. Our implementation will be based on :code:`flwr.client.NumPyClient` and we'll call it :code:`CifarClient`. -:code:`NumPyClient` is slighly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. +:code:`NumPyClient` is slightly easier to implement than :code:`Client` if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. :code:`CifarClient` needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model: #. :code:`set_parameters` @@ -278,7 +278,7 @@ We included type annotations to give you a better understanding of the data type return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)} All that's left to do it to define a function that loads both model and data, creates a :code:`CifarClient`, and starts this client. -You load your data and model by using :code:`cifar.py`. Start :code:`CifarClient` with the function :code:`fl.client.start_numpy_client()` by pointing it at the same IP adress we used in :code:`server.py`: +You load your data and model by using :code:`cifar.py`. Start :code:`CifarClient` with the function :code:`fl.client.start_client()` by pointing it at the same IP address we used in :code:`server.py`: .. code-block:: python @@ -292,7 +292,7 @@ You load your data and model by using :code:`cifar.py`. 
Start :code:`CifarClient # Start client client = CifarClient(model, trainloader, testloader, num_examples) - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client) + fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) if __name__ == "__main__": diff --git a/doc/source/example-walkthrough-pytorch-mnist.rst b/doc/source/example-walkthrough-pytorch-mnist.rst index ab311813f5de..f8eacc8647fe 100644 --- a/doc/source/example-walkthrough-pytorch-mnist.rst +++ b/doc/source/example-walkthrough-pytorch-mnist.rst @@ -1,11 +1,11 @@ Example: Walk-Through PyTorch & MNIST ===================================== -In this tutorial we will learn, how to train a Convolutional Neural Network on MNIST using Flower and PyTorch. +In this tutorial we will learn how to train a Convolutional Neural Network on MNIST using Flower and PyTorch. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. +*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -15,7 +15,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use PyTorch to solve a computer vision task, let's go ahead an install PyTorch and the **torchvision** library: +Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library: .. code-block:: shell @@ -32,51 +32,51 @@ Go ahead and launch on a terminal the *run-server.sh* script first as follows: .. 
code-block:: shell - $ bash ./run-server.sh + $ bash ./run-server.sh -Now that the server is up and running, go ahead and launch the clients. +Now that the server is up and running, go ahead and launch the clients. .. code-block:: shell - $ bash ./run-clients.sh + $ bash ./run-clients.sh Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client. .. code-block:: shell - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 - - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 - - Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% - + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 + + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 + + Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% + Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14% -Now, let's see what is really happening inside. +Now, let's see what is really happening inside. Flower Server ------------- Inside the server helper script *run-server.sh* you will find the following code that basically runs the :code:`server.py` -.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.server We can go a bit deeper and see that :code:`server.py` simply launches a server that will coordinate three rounds of training. 
-Flower Servers are very customizable, but for simple workloads, we can start a server using the :ref:`start_server ` function and leave all the configuration possibilities at their default values, as seen below. +Flower Servers are very customizable, but for simple workloads, we can start a server using the `start_server `_ function and leave all the configuration possibilities at their default values, as seen below. .. code-block:: python @@ -90,18 +90,18 @@ Flower Client Next, let's take a look at the *run-clients.sh* file. You will see that it contains the main loop that starts a set of *clients*. -.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.client \ --cid=$i \ --server_address=$SERVER_ADDRESS \ - --nb_clients=$NUM_CLIENTS + --nb_clients=$NUM_CLIENTS * **cid**: is the client ID. It is an integer that uniquely identifies client identifier. -* **sever_address**: String that identifies IP and port of the server. +* **server_address**: String that identifies IP and port of the server. * **nb_clients**: This defines the number of clients being created. This piece of information is not required by the client, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both *training* and *test* sets. -Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. +Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. After going through the argument parsing code at the beginning of our :code:`main` function, you will find a call to :code:`mnist.load_data`. This function is responsible for partitioning the original MNIST datasets (*training* and *test*) and returning a :code:`torch.utils.data.DataLoader` s for each of them. We then instantiate a :code:`PytorchMNISTClient` object with our client ID, our DataLoaders, the number of epochs in each round, and which device we want to use for training (CPU or GPU). 
@@ -152,7 +152,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - weights: fl.common.NDArrays + weights: fl.common.NDArrays Weights received by the server and set to local model @@ -179,8 +179,8 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.FitIns - Parameters sent by the server to be used during training. + ins: fl.common.FitIns + Parameters sent by the server to be used during training. Returns ------- @@ -214,9 +214,9 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.EvaluateIns - Parameters sent by the server to be used during testing. - + ins: fl.common.EvaluateIns + Parameters sent by the server to be used during testing. + Returns ------- @@ -262,9 +262,9 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it Parameters ---------- - x: Tensor + x: Tensor Mini-batch of shape (N,28,28) containing images from MNIST dataset. - + Returns ------- @@ -287,7 +287,7 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it return output -The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: +The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: .. code-block:: python @@ -312,7 +312,7 @@ The second thing to notice is that :code:`PytorchMNISTClient` class inherits fro """Evaluate the provided weights using the locally held dataset.""" -When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. 
+When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. These functions can both be found inside the same :code:`quickstart-pytorch.mnist` module: @@ -330,21 +330,21 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList Neural network model used in this example. - + train_loader: torch.utils.data.DataLoader - DataLoader used in traning. - - epochs: int - Number of epochs to run in each round. - - device: torch.device + DataLoader used in training. + + epochs: int + Number of epochs to run in each round. + + device: torch.device (Default value = torch.device("cpu")) Device where the network will be trained within a client. Returns ------- num_examples_train: int - Number of total samples used during traning. + Number of total samples used during training. """ model.train() @@ -399,10 +399,10 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList : Neural network model used in this example. - + test_loader: torch.utils.data.DataLoader : DataLoader used in test. - + device: torch.device : (Default value = torch.device("cpu")) Device where the network will be tested within a client. @@ -435,19 +435,19 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis Observe that these functions encapsulate regular training and test loops and provide :code:`fit` and :code:`evaluate` with final statistics for each round. -You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. -As a matter of fact, why not try and modify the code to an example of your liking? 
+You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. +As a matter of fact, why not try and modify the code to an example of your liking? Give It a Try ------------- -Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. +Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. Here are a few things you could try on your own and get more experience with Flower: - Try and change :code:`PytorchMNISTClient` so it can accept different architectures. - Modify the :code:`train` function so that it accepts different optimizers - Modify the :code:`test` function so that it proves not only the top-1 (regular accuracy) but also the top-5 accuracy? -- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? +- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? You are ready now. Enjoy learning in a federated way! diff --git a/doc/source/explanation-differential-privacy.rst b/doc/source/explanation-differential-privacy.rst index fcebd81983cc..69fd333f9b13 100644 --- a/doc/source/explanation-differential-privacy.rst +++ b/doc/source/explanation-differential-privacy.rst @@ -1,100 +1,139 @@ -Differential privacy +Differential Privacy ==================== +The information in datasets like healthcare, financial transactions, user preferences, etc., is valuable and has the potential for scientific breakthroughs and provides important business insights. +However, such data is also sensitive and there is a risk of compromising individual privacy. 
-Flower provides differential privacy (DP) wrapper classes for the easy integration of the central DP guarantees provided by DP-FedAvg into training pipelines defined in any of the various ML frameworks that Flower is compatible with. +Traditional methods like anonymization alone would not work because of attacks like Re-identification and Data Linkage. +That's where differential privacy comes in. It provides the possibility of analyzing data while ensuring the privacy of individuals. -.. warning:: - Please note that these components are still experimental; the correct configuration of DP for a specific task is still an unsolved problem. -.. note:: - The name DP-FedAvg is misleading since it can be applied on top of any FL algorithm that conforms to the general structure prescribed by the FedOpt family of algorithms. +Differential Privacy +-------------------- +Imagine two datasets that are identical except for a single record (for instance, Alice's data). +Differential Privacy (DP) guarantees that any analysis (M), like calculating the average income, will produce nearly identical results for both datasets (O and O' would be similar). +This preserves group patterns while obscuring individual details, ensuring the individual's information remains hidden in the crowd. -DP-FedAvg ---------- +.. image:: ./_static/DP/dp-intro.png + :align: center + :width: 400 + :alt: DP Intro -DP-FedAvg, originally proposed by McMahan et al. [mcmahan]_ and extended by Andrew et al. [andrew]_, is essentially FedAvg with the following modifications. -* **Clipping** : The influence of each client's update is bounded by clipping it. This is achieved by enforcing a cap on the L2 norm of the update, scaling it down if needed. -* **Noising** : Gaussian noise, calibrated to the clipping threshold, is added to the average computed at the server. 
+One of the most commonly used mechanisms to achieve DP is adding enough noise to the output of the analysis to mask the contribution of each individual in the data while preserving the overall accuracy of the analysis. -The distribution of the update norm has been shown to vary from task-to-task and to evolve as training progresses. This variability is crucial in understanding its impact on differential privacy guarantees, emphasizing the need for an adaptive approach [andrew]_ that continuously adjusts the clipping threshold to track a prespecified quantile of the update norm distribution. +Formal Definition +~~~~~~~~~~~~~~~~~ +Differential Privacy (DP) provides statistical guarantees against the information an adversary can infer through the output of a randomized algorithm. +It provides an unconditional upper bound on the influence of a single individual on the output of the algorithm by adding noise [1]. +A randomized mechanism +M provides (:math:`\epsilon`, :math:`\delta`)-differential privacy if for any two neighboring databases, D :sub:`1` and D :sub:`2`, that differ in only a single record, +and for all possible outputs S ⊆ Range(M): -Simplifying Assumptions -*********************** +.. math:: -We make (and attempt to enforce) a number of assumptions that must be satisfied to ensure that the training process actually realizes the :math:`(\epsilon, \delta)` guarantees the user has in mind when configuring the setup. + \small + P[M(D_{1}) \in S] \leq e^{\epsilon} P[M(D_{2}) \in S] + \delta -* **Fixed-size subsampling** :Fixed-size subsamples of the clients must be taken at each round, as opposed to variable-sized Poisson subsamples. -* **Unweighted averaging** : The contributions from all the clients must weighted equally in the aggregate to eliminate the requirement for the server to know in advance the sum of the weights of all clients available for selection.
-* **No client failures** : The set of available clients must stay constant across all rounds of training. In other words, clients cannot drop out or fail. -The first two are useful for eliminating a multitude of complications associated with calibrating the noise to the clipping threshold, while the third one is required to comply with the assumptions of the privacy analysis. +The :math:`\epsilon` parameter, also known as the privacy budget, is a metric of privacy loss. +It also controls the privacy-utility trade-off; lower :math:`\epsilon` values indicate higher levels of privacy but are likely to reduce utility as well. +The :math:`\delta` parameter accounts for a small probability on which the upper bound :math:`\epsilon` does not hold. +The amount of noise needed to achieve differential privacy is proportional to the sensitivity of the output, which measures the maximum change in the output due to the inclusion or removal of a single record. -.. note:: - These restrictions are in line with constraints imposed by Andrew et al. [andrew]_. -Customizable Responsibility for Noise injection -*********************************************** -In contrast to other implementations where the addition of noise is performed at the server, you can configure the site of noise injection to better match your threat model. We provide users with the flexibility to set up the training such that each client independently adds a small amount of noise to the clipped update, with the result that simply aggregating the noisy updates is equivalent to the explicit addition of noise to the non-noisy aggregate at the server. +Differential Privacy in Machine Learning +---------------------------------------- +DP can be utilized in machine learning to preserve the privacy of the training data. 
+Differentially private machine learning algorithms are designed in a way to prevent the algorithm to learn any specific information about any individual data points and subsequently prevent the model from revealing sensitive information. +Depending on the stage at which noise is introduced, various methods exist for applying DP to machine learning algorithms. +One approach involves adding noise to the training data (either to the features or labels), while another method entails injecting noise into the gradients of the loss function during model training. +Additionally, such noise can be incorporated into the model's output. +Differential Privacy in Federated Learning +------------------------------------------ +Federated learning is a data minimization approach that allows multiple parties to collaboratively train a model without sharing their raw data. +However, federated learning also introduces new privacy challenges. The model updates between parties and the central server can leak information about the local data. +These leaks can be exploited by attacks such as membership inference and property inference attacks, or model inversion attacks. -To be precise, if we let :math:`m` be the number of clients sampled each round and :math:`\sigma_\Delta` be the scale of the total Gaussian noise that needs to be added to the sum of the model updates, we can use simple maths to show that this is equivalent to each client adding noise with scale :math:`\sigma_\Delta/\sqrt{m}`. +DP can play a crucial role in federated learning to provide privacy for the clients' data. -Wrapper-based approach ----------------------- +Depending on the granularity of privacy provision or the location of noise addition, different forms of DP exist in federated learning. +In this explainer, we focus on two approaches of DP utilization in federated learning based on where the noise is added: at the server (also known as the center) or at the client (also known as the local). 
-Introducing DP to an existing workload can be thought of as adding an extra layer of security around it. This inspired us to provide the additional server and client-side logic needed to make the training process differentially private as wrappers for instances of the :code:`Strategy` and :code:`NumPyClient` abstract classes respectively. This wrapper-based approach has the advantage of being easily composable with other wrappers that someone might contribute to the Flower library in the future, e.g., for secure aggregation. Using Inheritance instead can be tedious because that would require the creation of new sub- classes every time a new class implementing :code:`Strategy` or :code:`NumPyClient` is defined. +- **Central Differential Privacy**: DP is applied by the server and the goal is to prevent the aggregated model from leaking information about each client's data. -Server-side logic -***************** +- **Local Differential Privacy**: DP is applied on the client side before sending any information to the server and the goal is to prevent the updates that are sent to the server from leaking any information about the client's data. -The first version of our solution was to define a decorator whose constructor accepted, among other things, a boolean-valued variable indicating whether adaptive clipping was to be enabled or not. We quickly realized that this would clutter its :code:`__init__()` function with variables corresponding to hyperparameters of adaptive clipping that would remain unused when it was disabled. A cleaner implementation could be achieved by splitting the functionality into two decorators, :code:`DPFedAvgFixed` and :code:`DPFedAvgAdaptive`, with the latter sub- classing the former. The constructors for both classes accept a boolean parameter :code:`server_side_noising`, which, as the name suggests, determines where noising is to be performed. 
+Central Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In this approach, which is also known as user-level DP, the central server is responsible for adding noise to the globally aggregated parameters. It should be noted that trust in the server is required. -DPFedAvgFixed -::::::::::::: +.. image:: ./_static/DP/CDP.png + :align: center + :width: 400 + :alt: Central Differential Privacy -The server-side capabilities required for the original version of DP-FedAvg, i.e., the one which performed fixed clipping, can be completely captured with the help of wrapper logic for just the following two methods of the :code:`Strategy` abstract class. +While there are various ways to implement central DP in federated learning, we concentrate on the algorithms proposed by [2] and [3]. +The overall approach is to clip the model updates sent by the clients and add some amount of noise to the aggregated model. +In each iteration, a random set of clients is chosen with a specific probability for training. +Each client performs local training on its own data. +The update of each client is then clipped by some value `S` (sensitivity `S`). +This would limit the impact of any individual client which is crucial for privacy and often beneficial for robustness. +A common approach to achieve this is by restricting the `L2` norm of the clients' model updates, ensuring that larger updates are scaled down to fit within the norm `S`. -#. :code:`configure_fit()` : The config dictionary being sent by the wrapped :code:`Strategy` to each client needs to be augmented with an additional value equal to the clipping threshold (keyed under :code:`dpfedavg_clip_norm`) and, if :code:`server_side_noising=true`, another one equal to the scale of the Gaussian noise that needs to be added at the client (keyed under :code:`dpfedavg_noise_stddev`). This entails *post*-processing of the results returned by the wrappee's implementation of :code:`configure_fit()`. -#. 
:code:`aggregate_fit()`: We check whether any of the sampled clients dropped out or failed to upload an update before the round timed out. In that case, we need to abort the current round, discarding any successful updates that were received, and move on to the next one. On the other hand, if all clients responded successfully, we must force the averaging of the updates to happen in an unweighted manner by intercepting the :code:`parameters` field of :code:`FitRes` for each received update and setting it to 1. Furthermore, if :code:`server_side_noising=true`, each update is perturbed with an amount of noise equal to what it would have been subjected to had client-side noising being enabled. This entails *pre*-processing of the arguments to this method before passing them on to the wrappee's implementation of :code:`aggregate_fit()`. +.. image:: ./_static/DP/clipping.png + :align: center + :width: 300 + :alt: clipping -.. note:: - We can't directly change the aggregation function of the wrapped strategy to force it to add noise to the aggregate, hence we simulate client-side noising to implement server-side noising. +Afterwards, the Gaussian mechanism is used to add noise in order to distort the sum of all clients' updates. +The amount of noise is scaled to the sensitivity value to obtain a privacy guarantee. +The Gaussian mechanism is used with a noise sampled from `N (0, σ²)` where `σ = ( noise_scale * S ) / (number of sampled clients)`. -These changes have been put together into a class called :code:`DPFedAvgFixed`, whose constructor accepts the strategy being decorated, the clipping threshold and the number of clients sampled every round as compulsory arguments. The user is expected to specify the clipping threshold since the order of magnitude of the update norms is highly dependent on the model being trained and providing a default value would be misleading. 
The number of clients sampled at every round is required to calculate the amount of noise that must be added to each individual update, either by the server or the clients. +Clipping +^^^^^^^^ -DPFedAvgAdaptive -:::::::::::::::: +There are two forms of clipping commonly used in Central DP: Fixed Clipping and Adaptive Clipping. -The additional functionality required to facilitate adaptive clipping has been provided in :code:`DPFedAvgAdaptive`, a subclass of :code:`DPFedAvgFixed`. It overrides the above-mentioned methods to do the following. +- **Fixed Clipping** : A predefined fixed threshold is set for the magnitude of clients' updates. Any update exceeding this threshold is clipped back to the threshold value. -#. :code:`configure_fit()` : It intercepts the config dict returned by :code:`super.configure_fit()` to add the key-value pair :code:`dpfedavg_adaptive_clip_enabled:True` to it, which the client interprets as an instruction to include an indicator bit (1 if update norm <= clipping threshold, 0 otherwise) in the results returned by it. +- **Adaptive Clipping** : The clipping threshold dynamically adjusts based on the observed update distribution [4]. It means that the clipping value is tuned during the rounds with respect to the quantile of the update norm distribution. -#. :code:`aggregate_fit()` : It follows a call to :code:`super.aggregate_fit()` with one to :code:`__update_clip_norm__()`, a procedure which adjusts the clipping threshold on the basis of the indicator bits received from the sampled clients. +The choice between fixed and adaptive clipping depends on various factors such as privacy requirements, data distribution, model complexity, and others. +Local Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~ -Client-side logic -***************** +In this approach, each client is responsible for performing DP.
+Local DP avoids the need for a fully trusted aggregator, but it should be noted that local DP leads to a decrease in accuracy but better privacy in comparison to central DP. -The client-side capabilities required can be completely captured through wrapper logic for just the :code:`fit()` method of the :code:`NumPyClient` abstract class. To be precise, we need to *post-process* the update computed by the wrapped client to clip it, if necessary, to the threshold value supplied by the server as part of the config dictionary. In addition to this, it may need to perform some extra work if either (or both) of the following keys are also present in the dict. +.. image:: ./_static/DP/LDP.png + :align: center + :width: 400 + :alt: Local Differential Privacy -* :code:`dpfedavg_noise_stddev` : Generate and add the specified amount of noise to the clipped update. -* :code:`dpfedavg_adaptive_clip_enabled` : Augment the metrics dict in the :code:`FitRes` object being returned to the server with an indicator bit, calculated as described earlier. +In this explainer, we focus on two forms of achieving Local DP: -Performing the :math:`(\epsilon, \delta)` analysis --------------------------------------------------- +- Each client adds noise to the local updates before sending them to the server. To achieve (:math:`\epsilon`, :math:`\delta`)-DP, considering the sensitivity of the local model to be ∆, Gaussian noise is applied with a noise scale of σ where: -Assume you have trained for :math:`n` rounds with sampling fraction :math:`q` and noise multiplier :math:`z`. In order to calculate the :math:`\epsilon` value this would result in for a particular :math:`\delta`, the following script may be used. +.. math:: + \small + \frac{∆ \times \sqrt{2 \times \log\left(\frac{1.25}{\delta}\right)}}{\epsilon} -.. 
code-block:: python - import tensorflow_privacy as tfp - max_order = 32 - orders = range(2, max_order + 1) - rdp = tfp.compute_rdp_sample_without_replacement(q, z, n, orders) - eps, _, _ = tfp.rdp_accountant.get_privacy_spent(rdp, target_delta=delta) +- Each client adds noise to the gradients of the model during the local training (DP-SGD). More specifically, in this approach, gradients are clipped and an amount of calibrated noise is injected into the gradients. -.. [mcmahan] McMahan et al. "Learning Differentially Private Recurrent Language Models." International Conference on Learning Representations (ICLR), 2017. -.. [andrew] Andrew, Galen, et al. "Differentially Private Learning with Adaptive Clipping." Advances in Neural Information Processing Systems (NeurIPS), 2021. +Please note that these two approaches are providing privacy at different levels. + + +**References:** + +[1] Dwork et al. The Algorithmic Foundations of Differential Privacy. + +[2] McMahan et al. Learning Differentially Private Recurrent Language Models. + +[3] Geyer et al. Differentially Private Federated Learning: A Client Level Perspective. + +[4] Galen et al. Differentially Private Learning with Adaptive Clipping. diff --git a/doc/source/explanation-federated-evaluation.rst b/doc/source/explanation-federated-evaluation.rst index 632241f70d36..bcdca9bae700 100644 --- a/doc/source/explanation-federated-evaluation.rst +++ b/doc/source/explanation-federated-evaluation.rst @@ -120,7 +120,7 @@ Federated evaluation can be configured from the server side. Built-in strategies # Create strategy strategy = fl.server.strategy.FedAvg( - # ... other FedAvg agruments + # ... 
other FedAvg arguments fraction_evaluate=0.2, min_evaluate_clients=2, min_available_clients=10, diff --git a/doc/source/how-to-configure-clients.rst b/doc/source/how-to-configure-clients.rst index 26c132125ccf..ff0a2f4033df 100644 --- a/doc/source/how-to-configure-clients.rst +++ b/doc/source/how-to-configure-clients.rst @@ -13,7 +13,7 @@ Configuration values are represented as a dictionary with ``str`` keys and value config_dict = { "dropout": True, # str key, bool value "learning_rate": 0.01, # str key, float value - "batch_size": 32, # str key, int value + "batch_size": 32, # str key, int value "optimizer": "sgd", # str key, str value } @@ -56,7 +56,7 @@ To make the built-in strategies use this function, we can pass it to ``FedAvg`` One the client side, we receive the configuration dictionary in ``fit``: .. code-block:: python - + class FlowerClient(flwr.client.NumPyClient): def fit(parameters, config): print(config["batch_size"]) # Prints `32` @@ -86,7 +86,7 @@ Configuring individual clients In some cases, it is necessary to send different configuration values to different clients. -This can be achieved by customizing an existing strategy or by `implementing a custom strategy from scratch `_. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list, the other clients in this round to not receive this "special" config value): +This can be achieved by customizing an existing strategy or by :doc:`implementing a custom strategy from scratch `. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list, the other clients in this round to not receive this "special" config value): .. 
code-block:: python diff --git a/doc/source/how-to-configure-logging.rst b/doc/source/how-to-configure-logging.rst index 3fcfe85b9592..d5559429a73c 100644 --- a/doc/source/how-to-configure-logging.rst +++ b/doc/source/how-to-configure-logging.rst @@ -129,4 +129,4 @@ Log to a remote service The :code:`fl.common.logger.configure` function, also allows specifying a host to which logs can be pushed (via :code:`POST`) through a native Python :code:`logging.handler.HTTPHandler`. This is a particularly useful feature in :code:`gRPC`-based Federated Learning workloads where otherwise gathering logs from all entities (i.e. the server and the clients) might be cumbersome. -Note that in Flower simulation, the server automatically displays all logs. You can still specify a :code:`HTTPHandler` should you whish to backup or analyze the logs somewhere else. +Note that in Flower simulation, the server automatically displays all logs. You can still specify a :code:`HTTPHandler` should you wish to backup or analyze the logs somewhere else. 
diff --git a/doc/source/how-to-enable-ssl-connections.rst b/doc/source/how-to-enable-ssl-connections.rst index fa59d4423c5a..051dd5711497 100644 --- a/doc/source/how-to-enable-ssl-connections.rst +++ b/doc/source/how-to-enable-ssl-connections.rst @@ -75,9 +75,9 @@ We are now going to show how to write a client which uses the previously generat client = MyFlowerClient() # Start client - fl.client.start_numpy_client( + fl.client.start_client( "localhost:8080", - client=client, + client=client.to_client(), root_certificates=Path(".cache/certificates/ca.crt").read_bytes(), ) diff --git a/doc/source/how-to-implement-strategies.rst b/doc/source/how-to-implement-strategies.rst index 7997503b65a8..01bbb3042973 100644 --- a/doc/source/how-to-implement-strategies.rst +++ b/doc/source/how-to-implement-strategies.rst @@ -233,7 +233,7 @@ The return value is a list of tuples, each representing the instructions that wi * Use the :code:`client_manager` to randomly sample all (or a subset of) available clients (each represented as a :code:`ClientProxy` object) * Pair each :code:`ClientProxy` with the same :code:`FitIns` holding the current global model :code:`parameters` and :code:`config` dict -More sophisticated implementations can use :code:`configure_fit` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the the list returned from :code:`configure_fit`. +More sophisticated implementations can use :code:`configure_fit` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the list returned from :code:`configure_fit`. .. 
note:: @@ -280,7 +280,7 @@ The return value is a list of tuples, each representing the instructions that wi * Use the :code:`client_manager` to randomly sample all (or a subset of) available clients (each represented as a :code:`ClientProxy` object) * Pair each :code:`ClientProxy` with the same :code:`EvaluateIns` holding the current global model :code:`parameters` and :code:`config` dict -More sophisticated implementations can use :code:`configure_evaluate` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the the list returned from :code:`configure_evaluate`. +More sophisticated implementations can use :code:`configure_evaluate` to implement custom client selection logic. A client will only participate in a round if the corresponding :code:`ClientProxy` is included in the list returned from :code:`configure_evaluate`. .. note:: diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index b2efde176fc9..aebe5f7316de 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth Install stable release ---------------------- +Using pip +~~~~~~~~~ + Stable releases are available on `PyPI `_:: python -m pip install flwr @@ -20,10 +23,29 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed python -m pip install flwr[simulation] +Using conda (or mamba) +~~~~~~~~~~~~~~~~~~~~~~ + +Flower can also be installed from the ``conda-forge`` channel. 
+ +If you have not added ``conda-forge`` to your channels, you will first need to run the following:: + + conda config --add channels conda-forge + conda config --set channel_priority strict + +Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``:: + + conda install flwr + +or with ``mamba``:: + + mamba install flwr + + Verify installation ------------------- -The following command can be used to verfiy if Flower was successfully installed. If everything worked, it should print the version of Flower to the command line:: +The following command can be used to verify if Flower was successfully installed. If everything worked, it should print the version of Flower to the command line:: python -c "import flwr;print(flwr.__version__)" 1.5.0 @@ -32,6 +54,11 @@ The following command can be used to verfiy if Flower was successfully installed Advanced installation options ----------------------------- +Install via Docker +~~~~~~~~~~~~~~~~~~ + +:doc:`How to run Flower using Docker ` + Install pre-release ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/how-to-monitor-simulation.rst b/doc/source/how-to-monitor-simulation.rst index 740004914eed..f6c26a701d94 100644 --- a/doc/source/how-to-monitor-simulation.rst +++ b/doc/source/how-to-monitor-simulation.rst @@ -161,7 +161,7 @@ However, you can overwrite the defaults. When starting a simulation, do the foll ram_memory = 16_000 * 1024 * 1024 # 16 GB fl.simulation.start_simulation( # ... - # all the args you were specyfing before + # all the args you were specifying before # ... ray_init_args = { "include_dashboard": True, # we need this one for tracking @@ -187,7 +187,7 @@ Let’s also specify the resource for a single client. fl.simulation.start_simulation( # ... - # all the args you were specyfing before + # all the args you were specifying before # ... 
ray_init_args = { "include_dashboard": True, # we need this one for tracking @@ -231,6 +231,6 @@ A: Either the simulation has already finished, or you still need to start Promet Resources --------- -Ray Dashboard: ``_ +Ray Dashboard: ``_ -Ray Metrics: ``_ +Ray Metrics: ``_ diff --git a/doc/source/how-to-run-flower-using-docker.rst b/doc/source/how-to-run-flower-using-docker.rst new file mode 100644 index 000000000000..ed034c820142 --- /dev/null +++ b/doc/source/how-to-run-flower-using-docker.rst @@ -0,0 +1,144 @@ +Run Flower using Docker +======================= + +The simplest way to get started with Flower is by using the pre-made Docker images, which you can +find on `Docker Hub `_. + +Before you start, make sure that the Docker daemon is running: + +.. code-block:: bash + + $ docker -v + Docker version 24.0.7, build afdd53b + +If you do not see the version of Docker but instead get an error saying that the command +was not found, you will need to install Docker first. You can find installation instructions +`here `_. + +.. note:: + + On Linux, Docker commands require ``sudo`` privilege. If you want to avoid using ``sudo``, + you can follow the `Post-installation steps `_ + on the official Docker website. + +Flower server +------------- + +Quickstart +~~~~~~~~~~ + +If you're looking to try out Flower, you can use the following command: + +.. code-block:: bash + + $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/server:1.7.0-py3.11-ubuntu22.04 \ + --insecure + +The command will pull the Docker image with the tag ``1.7.0-py3.11-ubuntu22.04`` from Docker Hub. +The tag contains the information about which versions of Flower, Python and Ubuntu are used. In this case, it +uses Flower 1.7.0, Python 3.11 and Ubuntu 22.04. The ``--rm`` flag tells Docker to remove +the container after it exits. + +.. note:: + + By default, the Flower server keeps state in-memory. When using the Docker flag + ``--rm``, the state is not persisted between container starts.
We will show below how to save the + state in a file on your host system. + +The ``-p :`` flag tells Docker to map the ports ``9091``/``9092`` of the host to +``9091``/``9092`` of the container, allowing you to access the Driver API on ``http://localhost:9091`` +and the Fleet API on ``http://localhost:9092``. Lastly, any flag that comes after the tag is passed +to the Flower server. Here, we are passing the flag ``--insecure``. + +.. attention:: + + The ``--insecure`` flag enables insecure communication (using HTTP, not HTTPS) and should only be used + for testing purposes. We strongly recommend enabling + `SSL `_ + when deploying to a production environment. + +You can use ``--help`` to view all available flags that the server supports: + +.. code-block:: bash + + $ docker run --rm flwr/server:1.7.0-py3.11-ubuntu22.04 --help + +Mounting a volume to store the state on the host system +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to persist the state of the server on your host system, all you need to do is specify a +path where you want to save the file on your host system and a name for the database file. In the +example below, we tell Docker via the flag ``-v`` to mount the user's home directory +(``~/`` on your host) into the ``/app/`` directory of the container. Furthermore, we use the +flag ``--database`` to specify the name of the database file. + +.. code-block:: bash + + $ docker run --rm \ + -p 9091:9091 -p 9092:9092 -v ~/:/app/ flwr/server:1.7.0-py3.11-ubuntu22.04 \ + --insecure \ + --database state.db + +As soon as the server starts, the file ``state.db`` is created in the user's home directory on +your host system. If the file already exists, the server tries to restore the state from the file. +To start the server with an empty database, simply remove the ``state.db`` file. 
+ +Enabling SSL for secure connections +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To enable SSL, you will need a CA certificate, a server certificate and a server private key. + +.. note:: + For testing purposes, you can generate your own self-signed certificates. The + `Enable SSL connections `_ + page contains a section that will guide you through the process. + +Assuming all files we need are in the local ``certificates`` directory, we can use the flag +``-v`` to mount the local directory into the ``/app/`` directory of the container. This allows the +server to access the files within the container. Finally, we pass the names of the certificates to +the server with the ``--certificates`` flag. + +.. code-block:: bash + + $ docker run --rm \ + -p 9091:9091 -p 9092:9092 -v ./certificates/:/app/ flwr/server:1.7.0-py3.11-ubuntu22.04 \ + --certificates ca.crt server.pem server.key + +Using a different Flower or Python version +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to use a different version of Flower or Python, you can do so by changing the tag. +All versions we provide are available on `Docker Hub `_. + +Pinning a Docker image to a specific version +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +It may happen that we update the images behind the tags. Such updates usually include security +updates of system dependencies that should not change the functionality of Flower. However, if you +want to ensure that you always use the same image, you can specify the hash of the image instead of +the tag. + +The following command returns the current image hash referenced by the ``server:1.7.0-py3.11-ubuntu22.04`` tag: + +.. code-block:: bash + + $ docker inspect --format='{{index .RepoDigests 0}}' flwr/server:1.7.0-py3.11-ubuntu22.04 + flwr/server@sha256:c4be5012f9d73e3022e98735a889a463bb2f4f434448ebc19c61379920b1b327 + +Next, we can pin the hash when running a new server container: + +.. 
code-block:: bash + + $ docker run \ + --rm flwr/server@sha256:c4be5012f9d73e3022e98735a889a463bb2f4f434448ebc19c61379920b1b327 \ + --insecure + +Setting environment variables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To set a variable inside a Docker container, you can use the ``-e =`` flag. + +.. code-block:: bash + + $ docker run -e FLWR_TELEMETRY_ENABLED=0 \ + --rm flwr/server:1.7.0-py3.11-ubuntu22.04 --insecure diff --git a/doc/source/how-to-run-simulations.rst b/doc/source/how-to-run-simulations.rst index 707e3d3ffe84..d1dcb511ed51 100644 --- a/doc/source/how-to-run-simulations.rst +++ b/doc/source/how-to-run-simulations.rst @@ -7,7 +7,7 @@ Run simulations Simulating Federated Learning workloads is useful for a multitude of use-cases: you might want to run your workload on a large cohort of clients but without having to source, configure and mange a large number of physical devices; you might want to run your FL workloads as fast as possible on the compute systems you have access to without having to go through a complex setup process; you might want to validate your algorithm on different scenarios at varying levels of data and system heterogeneity, client availability, privacy budgets, etc. These are among some of the use-cases where simulating FL workloads makes sense. Flower can accommodate these scenarios by means of its `VirtualClientEngine `_ or VCE. -The :code:`VirtualClientEngine` schedules, launches and manages `virtual` clients. These clients are identical to `non-virtual` clients (i.e. the ones you launch via the command `flwr.client.start_numpy_client `_) in the sense that they can be configure by creating a class inheriting, for example, from `flwr.client.NumPyClient `_ and therefore behave in an identical way. In addition to that, clients managed by the :code:`VirtualClientEngine` are: +The :code:`VirtualClientEngine` schedules, launches and manages `virtual` clients. These clients are identical to `non-virtual` clients (i.e. 
the ones you launch via the command `flwr.client.start_client `_) in the sense that they can be configured by creating a class inheriting, for example, from `flwr.client.NumPyClient `_ and therefore behave in an identical way. In addition to that, clients managed by the :code:`VirtualClientEngine` are: * resource-aware: this means that each client gets assigned a portion of the compute and memory on your system. You as a user can control this at the beginning of the simulation and allows you to control the degree of parallelism of your Flower FL simulation. The fewer the resources per client, the more clients can run concurrently on the same hardware. * self-managed: this means that you as a user do not need to launch clients manually, instead this gets delegated to :code:`VirtualClientEngine`'s internals. @@ -29,7 +29,7 @@ Running Flower simulations still require you to define your client class, a stra def client_fn(cid: str): # Return a standard Flower client - return MyFlowerClient() + return MyFlowerClient().to_client() # Launch the simulation hist = fl.simulation.start_simulation( diff --git a/doc/source/how-to-save-and-load-model-checkpoints.rst b/doc/source/how-to-save-and-load-model-checkpoints.rst index 404df485fbae..0d711e375cd8 100644 --- a/doc/source/how-to-save-and-load-model-checkpoints.rst +++ b/doc/source/how-to-save-and-load-model-checkpoints.rst @@ -91,3 +91,7 @@ To load your progress, you simply append the following lines to your code. Note print("Loading pre-trained model from: ", latest_round_file) state_dict = torch.load(latest_round_file) net.load_state_dict(state_dict) + state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()] + parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays) + +Return/use this object of type ``Parameters`` wherever necessary, such as in the ``initial_parameters`` when defining a ``Strategy``.
\ No newline at end of file diff --git a/doc/source/how-to-upgrade-to-flower-1.0.rst b/doc/source/how-to-upgrade-to-flower-1.0.rst index fd380e95d69c..3a55a1a953f5 100644 --- a/doc/source/how-to-upgrade-to-flower-1.0.rst +++ b/doc/source/how-to-upgrade-to-flower-1.0.rst @@ -50,7 +50,7 @@ Strategies / ``start_server`` / ``start_simulation`` - Replace ``num_rounds=1`` in ``start_simulation`` with the new ``config=ServerConfig(...)`` (see previous item) - Remove ``force_final_distributed_eval`` parameter from calls to ``start_server``. Distributed evaluation on all clients can be enabled by configuring the strategy to sample all clients for evaluation after the last round of training. - Rename parameter/ndarray conversion functions: - + - ``parameters_to_weights`` --> ``parameters_to_ndarrays`` - ``weights_to_parameters`` --> ``ndarrays_to_parameters`` @@ -81,11 +81,11 @@ Optional improvements Along with the necessary changes above, there are a number of potential improvements that just became possible: -- Remove "placeholder" methods from subclasses of ``Client`` or ``NumPyClient``. If you, for example, use server-side evaluation, then empy placeholder implementations of ``evaluate`` are no longer necessary. +- Remove "placeholder" methods from subclasses of ``Client`` or ``NumPyClient``. If you, for example, use server-side evaluation, then empty placeholder implementations of ``evaluate`` are no longer necessary. - Configure the round timeout via ``start_simulation``: ``start_simulation(..., config=flwr.server.ServerConfig(num_rounds=3, round_timeout=600.0), ...)`` Further help ------------ -Most official `Flower code examples `_ are already updated to Flower 1.0, they can serve as a reference for using the Flower 1.0 API. If there are further questionsm, `join the Flower Slack `_ and use the channgel ``#questions``. +Most official `Flower code examples `_ are already updated to Flower 1.0, they can serve as a reference for using the Flower 1.0 API. 
If there are further questions, `join the Flower Slack `_ and use the channel ``#questions``. diff --git a/doc/source/how-to-use-built-in-mods.rst b/doc/source/how-to-use-built-in-mods.rst new file mode 100644 index 000000000000..341139175074 --- /dev/null +++ b/doc/source/how-to-use-built-in-mods.rst @@ -0,0 +1,89 @@ +Use Built-in Mods +================= + +**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** + +In this tutorial, we will learn how to utilize built-in mods to augment the behavior of a ``ClientApp``. Mods (sometimes also called Modifiers) allow us to perform operations before and after a task is processed in the ``ClientApp``. + +What are Mods? +-------------- + +A Mod is a callable that wraps around a ``ClientApp``. It can manipulate or inspect the incoming ``Message`` and the resulting outgoing ``Message``. The signature for a ``Mod`` is as follows: + +.. code-block:: python + + ClientApp = Callable[[Message, Context], Message] + Mod = Callable[[Message, Context, ClientApp], Message] + +A typical mod function might look something like this: + +.. code-block:: python + + def example_mod(msg: Message, ctx: Context, nxt: ClientApp) -> Message: + # Do something with incoming Message (or Context) + # before passing to the inner ``ClientApp`` + msg = nxt(msg, ctx) + # Do something with outgoing Message (or Context) + # before returning + return msg + +Using Mods +---------- + +To use mods in your ``ClientApp``, you can follow these steps: + +1. Import the required mods +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, import the built-in mod you intend to use: + +.. code-block:: python + + import flwr as fl + from flwr.client.mod import example_mod_1, example_mod_2 + +2. Define your client function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Define your client function (``client_fn``) that will be wrapped by the mod(s): + +.. code-block:: python + + def client_fn(cid): + # Your client code goes here. 
+ return # your client + +3. Create the ``ClientApp`` with mods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Create your ``ClientApp`` and pass the mods as a list to the ``mods`` argument. The order in which you provide the mods matters: + +.. code-block:: python + + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + example_mod_1, # Mod 1 + example_mod_2, # Mod 2 + ] + ) + +Order of execution +------------------ + +When the ``ClientApp`` runs, the mods are executed in the order they are provided in the list: + +1. ``example_mod_1`` (outermost mod) +2. ``example_mod_2`` (next mod) +3. Message handler (core function that handles the incoming ``Message`` and returns the outgoing ``Message``) +4. ``example_mod_2`` (on the way back) +5. ``example_mod_1`` (outermost mod on the way back) + +Each mod has a chance to inspect and modify the incoming ``Message`` before passing it to the next mod, and likewise with the outgoing ``Message`` before returning it up the stack. + +Conclusion +---------- + +By following this guide, you have learned how to effectively use mods to enhance your ``ClientApp``'s functionality. Remember that the order of mods is crucial and affects how the input and output are processed. + +Enjoy building a more robust and flexible ``ClientApp`` with mods! diff --git a/doc/source/how-to-use-differential-privacy.rst b/doc/source/how-to-use-differential-privacy.rst new file mode 100644 index 000000000000..c8901bd906cc --- /dev/null +++ b/doc/source/how-to-use-differential-privacy.rst @@ -0,0 +1,126 @@ +Use Differential Privacy +------------------------ +This guide explains how you can utilize differential privacy in the Flower framework. If you are not yet familiar with differential privacy, you can refer to :doc:`explanation-differential-privacy`. + +.. warning:: + + Differential Privacy in Flower is in a preview phase. 
If you plan to use these features in a production environment with sensitive data, feel free to contact us to discuss your requirements and to receive guidance on how to best use these features. + + +Central Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This approach consists of two separate phases: clipping of the updates and adding noise to the aggregated model. +For the clipping phase, Flower framework has made it possible to decide whether to perform clipping on the server side or the client side. + +- **Server-side Clipping**: This approach has the advantage of the server enforcing uniform clipping across all clients' updates and reducing the communication overhead for clipping values. However, it also has the disadvantage of increasing the computational load on the server due to the need to perform the clipping operation for all clients. +- **Client-side Clipping**: This approach has the advantage of reducing the computational overhead on the server. However, it also has the disadvantage of lacking centralized control, as the server has less control over the clipping process. + + + +Server-side Clipping +^^^^^^^^^^^^^^^^^^^^ +For central DP with server-side clipping, there are two :code:`Strategy` classes that act as wrappers around the actual :code:`Strategy` instance (for example, :code:`FedAvg`). +The two wrapper classes are :code:`DifferentialPrivacyServerSideFixedClipping` and :code:`DifferentialPrivacyServerSideAdaptiveClipping` for fixed and adaptive clipping. + +.. image:: ./_static/DP/serversideCDP.png + :align: center + :width: 700 + :alt: server side clipping + + +The code sample below enables the :code:`FedAvg` strategy to use server-side fixed clipping using the :code:`DifferentialPrivacyServerSideFixedClipping` wrapper class. +The same approach can be used with :code:`DifferentialPrivacyServerSideAdaptiveClipping` by adjusting the corresponding input parameters. + +..
code-block:: python + + from flwr.server.strategy import DifferentialPrivacyServerSideFixedClipping + + # Create the strategy + strategy = fl.server.strategy.FedAvg(...) + + # Wrap the strategy with the DifferentialPrivacyServerSideFixedClipping wrapper + dp_strategy = DifferentialPrivacyServerSideFixedClipping( + strategy, + cfg.noise_multiplier, + cfg.clipping_norm, + cfg.num_sampled_clients, + ) + + + +Client-side Clipping +^^^^^^^^^^^^^^^^^^^^ +For central DP with client-side clipping, the server sends the clipping value to selected clients on each round. +Clients can use existing Flower :code:`Mods` to perform the clipping. +Two mods are available for fixed and adaptive client-side clipping: :code:`fixedclipping_mod` and :code:`adaptiveclipping_mod` with corresponding server-side wrappers :code:`DifferentialPrivacyClientSideFixedClipping` and :code:`DifferentialPrivacyClientSideAdaptiveClipping`. + +.. image:: ./_static/DP/clientsideCDP.png + :align: center + :width: 800 + :alt: client side clipping + + +The code sample below enables the :code:`FedAvg` strategy to use differential privacy with client-side fixed clipping using both the :code:`DifferentialPrivacyClientSideFixedClipping` wrapper class and, on the client, :code:`fixedclipping_mod`: + +.. code-block:: python + + from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping + + # Create the strategy + strategy = fl.server.strategy.FedAvg(...) + + # Wrap the strategy with the DifferentialPrivacyClientSideFixedClipping wrapper + dp_strategy = DifferentialPrivacyClientSideFixedClipping( + strategy, + cfg.noise_multiplier, + cfg.clipping_norm, + cfg.num_sampled_clients, + ) + +In addition to the server-side strategy wrapper, the :code:`ClientApp` needs to configure the matching :code:`fixedclipping_mod` to perform the client-side clipping: + +..
code-block:: python + + from flwr.client.mod import fixedclipping_mod + + # Add fixedclipping_mod to the client-side mods + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + fixedclipping_mod, + ] + ) + + +Local Differential Privacy +~~~~~~~~~~~~~~~~~~~~~~~~~~ +To utilize local differential privacy (DP) and add noise to the client model parameters before transmitting them to the server in Flower, you can use the `LocalDpMod`. The following hyperparameters need to be set: clipping norm value, sensitivity, epsilon, and delta. + +.. image:: ./_static/DP/localdp.png + :align: center + :width: 700 + :alt: local DP mod + +Below is a code example that shows how to use :code:`LocalDpMod`: + +.. code-block:: python + + from flwr.client.mod.localdp_mod import LocalDpMod + + # Create an instance of the mod with the required params + local_dp_obj = LocalDpMod( + cfg.clipping_norm, cfg.sensitivity, cfg.epsilon, cfg.delta + ) + # Add local_dp_obj to the client-side mods + + app = fl.client.ClientApp( + client_fn=client_fn, + mods=[local_dp_obj], + ) + + +Please note that the order of mods, especially those that modify parameters, is important when using multiple modifiers. Typically, differential privacy (DP) modifiers should be the last to operate on parameters. + +Local Training using Privacy Engines +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +For ensuring data instance-level privacy during local model training on the client side, consider leveraging privacy engines such as Opacus and TensorFlow Privacy. For examples of using Flower with these engines, please refer to the Flower examples directory (`Opacus `_, `Tensorflow Privacy `_). 
\ No newline at end of file diff --git a/doc/source/how-to-use-strategies.rst b/doc/source/how-to-use-strategies.rst index 6d24f97bd7f6..d0e2cd63a091 100644 --- a/doc/source/how-to-use-strategies.rst +++ b/doc/source/how-to-use-strategies.rst @@ -45,7 +45,7 @@ Configuring client fit and client evaluate ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The server can pass new configuration values to the client each round by providing a function to :code:`on_fit_config_fn`. The provided function will be called by the strategy and must return a dictionary of configuration key values pairs that will be sent to the client. -It must return a dictionary of arbitraty configuration values :code:`client.fit` and :code:`client.evaluate` functions during each round of federated learning. +It must return a dictionary of arbitrary configuration values :code:`client.fit` and :code:`client.evaluate` functions during each round of federated learning. .. code-block:: python diff --git a/doc/source/index.rst b/doc/source/index.rst index c4a313414d3a..894155be03f1 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -4,7 +4,7 @@ Flower Framework Documentation .. meta:: :description: Check out the documentation of the main Flower Framework enabling easy Python development for Federated Learning. -Welcome to Flower's documentation. `Flower `_ is a friendly federated learning framework. +Welcome to Flower's documentation. `Flower `_ is a friendly federated learning framework. Join the Flower Community @@ -12,7 +12,7 @@ Join the Flower Community The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts. -.. button-link:: https://flower.dev/join-slack +.. button-link:: https://flower.ai/join-slack :color: primary :shadow: @@ -91,6 +91,9 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal. 
how-to-configure-logging how-to-enable-ssl-connections how-to-upgrade-to-flower-1.0 + how-to-use-built-in-mods + how-to-run-flower-using-docker + how-to-use-differential-privacy .. toctree:: :maxdepth: 1 @@ -119,11 +122,17 @@ References Information-oriented API reference and other reference material. +.. autosummary:: + :toctree: ref-api + :template: autosummary/module.rst + :caption: API reference + :recursive: + + flwr + .. toctree:: :maxdepth: 2 - :caption: API reference - ref-api-flwr ref-api-cli .. toctree:: @@ -160,6 +169,7 @@ The Flower community welcomes contributions. The following docs are intended to contributor-how-to-write-documentation contributor-how-to-release-flower contributor-how-to-contribute-translations + contributor-how-to-build-docker-images .. toctree:: :maxdepth: 1 diff --git a/doc/source/ref-api-cli.rst b/doc/source/ref-api-cli.rst index 039a2ea27cf8..63579143755d 100644 --- a/doc/source/ref-api-cli.rst +++ b/doc/source/ref-api-cli.rst @@ -1,42 +1,52 @@ Flower CLI reference ==================== -.. _flower-server-apiref: +.. _flower-superlink-apiref: -flower-server -~~~~~~~~~~~~~ +flower-superlink +~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_server - :prog: flower-server + :func: _parse_args_run_superlink + :prog: flower-superlink -.. _flower-driver-apiref: +.. _flower-driver-api-apiref: flower-driver-api ~~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_driver + :func: _parse_args_run_driver_api :prog: flower-driver-api -.. _flower-fleet-apiref: +.. _flower-fleet-api-apiref: flower-fleet-api ~~~~~~~~~~~~~~~~ .. argparse:: :module: flwr.server.app - :func: _parse_args_fleet + :func: _parse_args_run_fleet_api :prog: flower-fleet-api -.. .. _flower-client-apiref: +.. _flower-client-app-apiref: + +flower-client-app +~~~~~~~~~~~~~~~~~ + +.. argparse:: + :module: flwr.client.app + :func: _parse_args_run_client_app + :prog: flower-client-app + +.. _flower-server-app-apiref: -.. 
flower-client -.. ~~~~~~~~~~~~~ +flower-server-app +~~~~~~~~~~~~~~~~~ - .. argparse:: -.. :filename: flwr.client -.. :func: run_client -.. :prog: flower-client +.. argparse:: + :module: flwr.server.run_serverapp + :func: _parse_args_run_server_app + :prog: flower-server-app diff --git a/doc/source/ref-api-flwr.rst b/doc/source/ref-api-flwr.rst deleted file mode 100644 index e1983cd92c90..000000000000 --- a/doc/source/ref-api-flwr.rst +++ /dev/null @@ -1,265 +0,0 @@ -flwr (Python API reference) -=========================== - - -.. _flwr-client-apiref: - -client ------- - -.. automodule:: flwr.client - -.. _flwr-client-Client-apiref: - -Client -~~~~~~ - -.. autoclass:: flwr.client.Client - :members: - - -.. _flwr-client-start_client-apiref: - -start_client -~~~~~~~~~~~~ - -.. autofunction:: flwr.client.start_client - - -.. _flwr-client-NumPyClient-apiref: - -NumPyClient -~~~~~~~~~~~ - -.. autoclass:: flwr.client.NumPyClient - :members: - - -.. _flwr-client-start_numpy_client-apiref: - -start_numpy_client -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.client.start_numpy_client - - -.. _flwr-simulation-start_simulation-apiref: - -start_simulation -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.simulation.start_simulation - - -.. _flwr-server-apiref: - -server ------- - -.. automodule:: flwr.server - - -.. _flwr-server-start_server-apiref: - -server.start_server -~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.server.start_server - - -.. _flwr-server-strategy-apiref: - -server.strategy -~~~~~~~~~~~~~~~ - -.. automodule:: flwr.server.strategy - - -.. _flwr-server-strategy-Strategy-apiref: - -server.strategy.Strategy -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Strategy - :members: - - -.. _flwr-server-strategy-FedAvg-apiref: - -server.strategy.FedAvg -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAvg - :members: - - .. automethod:: __init__ - - -.. 
_flwr-server-strategy-FedAvgM-apiref: - -server.strategy.FedAvgM -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAvgM - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedMedian-apiref: - -server.strategy.FedMedian -^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedMedian - :members: - - .. automethod:: __init__ - -.. _flwr-server-strategy-QFedAvg-apiref: - -server.strategy.QFedAvg -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.QFedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FaultTolerantFedAvg-apiref: - -server.strategy.FaultTolerantFedAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FaultTolerantFedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedOpt-apiref: - -server.strategy.FedOpt -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedOpt - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedProx-apiref: - -server.strategy.FedProx -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedProx - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedAdagrad-apiref: - -server.strategy.FedAdagrad -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAdagrad - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedAdam-apiref: - -server.strategy.FedAdam -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAdam - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedYogi-apiref: - -server.strategy.FedYogi -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedYogi - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedTrimmedAvg-apiref: - -server.strategy.FedTrimmedAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedTrimmedAvg - :members: - - .. automethod:: __init__ - - -.. 
_flwr-server-strategy-Krum-apiref: - -server.strategy.Krum -^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Krum - :members: - - .. automethod:: __init__ - -.. _flwr-server-strategy-Bulyan-apiref: - -server.strategy.Bulyan -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Bulyan - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedXgbNnAvg-apiref: - -server.strategy.FedXgbNnAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedXgbNnAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-DPFedAvgAdaptive-apiref: - -server.strategy.DPFedAvgAdaptive -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.DPFedAvgAdaptive - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-DPFedAvgFixed-apiref: - -server.strategy.DPFedAvgFixed -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.DPFedAvgFixed - :members: - - .. automethod:: __init__ - -common ------- - -.. automodule:: flwr.common - :members: - :exclude-members: event diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 030f7618f4b2..1a6524d29353 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,11 +2,102 @@ ## Unreleased -- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) +### What's new? + +### Incompatible changes + +## v1.7.0 (2024-02-05) + +### Thanks to our contributors + +We would like to give our special thanks to all the contributors who made the new version of Flower possible (in `git shortlog` order): + +`Aasheesh Singh`, `Adam Narozniak`, `Aml Hassan Esmil`, `Charles Beauville`, `Daniel J. 
Beutel`, `Daniel Nata Nugraha`, `Edoardo Gabrielli`, `Gustavo Bertoli`, `HelinLin`, `Heng Pan`, `Javier`, `M S Chaitanya Kumar`, `Mohammad Naseri`, `Nikos Vlachakis`, `Pritam Neog`, `Robert Kuska`, `Robert Steiner`, `Taner Topal`, `Yahia Salaheldin Shaaban`, `Yan Gao`, `Yasar Abbas` + +### What's new? + +- **Introduce stateful clients (experimental)** ([#2770](https://github.com/adap/flower/pull/2770), [#2686](https://github.com/adap/flower/pull/2686), [#2696](https://github.com/adap/flower/pull/2696), [#2643](https://github.com/adap/flower/pull/2643), [#2769](https://github.com/adap/flower/pull/2769)) + + Subclasses of `Client` and `NumPyClient` can now store local state that remains on the client. Let's start with the highlight first: this new feature is compatible with both simulated clients (via `start_simulation`) and networked clients (via `start_client`). It's also the first preview of new abstractions like `Context` and `RecordSet`. Clients can access state of type `RecordSet` via `state: RecordSet = self.context.state`. Changes to this `RecordSet` are preserved across different rounds of execution to enable stateful computations in a unified way across simulation and deployment. + +- **Improve performance** ([#2293](https://github.com/adap/flower/pull/2293)) + + Flower is faster than ever. All `FedAvg`-derived strategies now use in-place aggregation to reduce memory consumption. The Flower client serialization/deserialization has been rewritten from the ground up, which results in significant speedups, especially when the client-side training time is short. + +- **Support Federated Learning with Apple MLX and Flower** ([#2693](https://github.com/adap/flower/pull/2693)) + + Flower has official support for federated learning using [Apple MLX](https://ml-explore.github.io/mlx) via the new `quickstart-mlx` code example. 
+ +- **Introduce new XGBoost cyclic strategy** ([#2666](https://github.com/adap/flower/pull/2666), [#2668](https://github.com/adap/flower/pull/2668)) + + A new strategy called `FedXgbCyclic` supports a client-by-client style of training (often called cyclic). The `xgboost-comprehensive` code example shows how to use it in a full project. In addition to that, `xgboost-comprehensive` now also supports simulation mode. With this, Flower offers best-in-class XGBoost support. + +- **Support Python 3.11** ([#2394](https://github.com/adap/flower/pull/2394)) + + Framework tests now run on Python 3.8, 3.9, 3.10, and 3.11. This will ensure better support for users using more recent Python versions. + +- **Update gRPC and ProtoBuf dependencies** ([#2814](https://github.com/adap/flower/pull/2814)) + + The `grpcio` and `protobuf` dependencies were updated to their latest versions for improved security and performance. + +- **Introduce Docker image for Flower server** ([#2700](https://github.com/adap/flower/pull/2700), [#2688](https://github.com/adap/flower/pull/2688), [#2705](https://github.com/adap/flower/pull/2705), [#2695](https://github.com/adap/flower/pull/2695), [#2747](https://github.com/adap/flower/pull/2747), [#2746](https://github.com/adap/flower/pull/2746), [#2680](https://github.com/adap/flower/pull/2680), [#2682](https://github.com/adap/flower/pull/2682), [#2701](https://github.com/adap/flower/pull/2701)) + + The Flower server can now be run using an official Docker image. A new how-to guide explains [how to run Flower using Docker](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html). An official Flower client Docker image will follow. 
+ +- **Introduce** `flower-via-docker-compose` **example** ([#2626](https://github.com/adap/flower/pull/2626)) + +- **Introduce** `quickstart-sklearn-tabular` **example** ([#2719](https://github.com/adap/flower/pull/2719)) + +- **Introduce** `custom-metrics` **example** ([#1958](https://github.com/adap/flower/pull/1958)) + +- **Update code examples to use Flower Datasets** ([#2450](https://github.com/adap/flower/pull/2450), [#2456](https://github.com/adap/flower/pull/2456), [#2318](https://github.com/adap/flower/pull/2318), [#2712](https://github.com/adap/flower/pull/2712)) + + Several code examples were updated to use [Flower Datasets](https://flower.ai/docs/datasets/). + +- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381), [#2805](https://github.com/adap/flower/pull/2805), [#2782](https://github.com/adap/flower/pull/2782), [#2806](https://github.com/adap/flower/pull/2806), [#2829](https://github.com/adap/flower/pull/2829), [#2825](https://github.com/adap/flower/pull/2825), [#2816](https://github.com/adap/flower/pull/2816), [#2726](https://github.com/adap/flower/pull/2726), [#2659](https://github.com/adap/flower/pull/2659), [#2655](https://github.com/adap/flower/pull/2655)) + + Many Flower code examples received substantial updates. 
- **Update Flower Baselines** - - HFedXGBoost [#2226](https://github.com/adap/flower/pull/2226) + - HFedXGBoost ([#2226](https://github.com/adap/flower/pull/2226), [#2771](https://github.com/adap/flower/pull/2771)) + - FedVSSL ([#2412](https://github.com/adap/flower/pull/2412)) + - FedNova ([#2179](https://github.com/adap/flower/pull/2179)) + - HeteroFL ([#2439](https://github.com/adap/flower/pull/2439)) + - FedAvgM ([#2246](https://github.com/adap/flower/pull/2246)) + - FedPara ([#2722](https://github.com/adap/flower/pull/2722)) + +- **Improve documentation** ([#2674](https://github.com/adap/flower/pull/2674), [#2480](https://github.com/adap/flower/pull/2480), [#2826](https://github.com/adap/flower/pull/2826), [#2727](https://github.com/adap/flower/pull/2727), [#2761](https://github.com/adap/flower/pull/2761), [#2900](https://github.com/adap/flower/pull/2900)) + +- **Improved testing and development infrastructure** ([#2797](https://github.com/adap/flower/pull/2797), [#2676](https://github.com/adap/flower/pull/2676), [#2644](https://github.com/adap/flower/pull/2644), [#2656](https://github.com/adap/flower/pull/2656), [#2848](https://github.com/adap/flower/pull/2848), [#2675](https://github.com/adap/flower/pull/2675), [#2735](https://github.com/adap/flower/pull/2735), [#2767](https://github.com/adap/flower/pull/2767), [#2732](https://github.com/adap/flower/pull/2732), [#2744](https://github.com/adap/flower/pull/2744), [#2681](https://github.com/adap/flower/pull/2681), [#2699](https://github.com/adap/flower/pull/2699), [#2745](https://github.com/adap/flower/pull/2745), [#2734](https://github.com/adap/flower/pull/2734), [#2731](https://github.com/adap/flower/pull/2731), [#2652](https://github.com/adap/flower/pull/2652), [#2720](https://github.com/adap/flower/pull/2720), [#2721](https://github.com/adap/flower/pull/2721), [#2717](https://github.com/adap/flower/pull/2717), [#2864](https://github.com/adap/flower/pull/2864), 
[#2694](https://github.com/adap/flower/pull/2694), [#2709](https://github.com/adap/flower/pull/2709), [#2658](https://github.com/adap/flower/pull/2658), [#2796](https://github.com/adap/flower/pull/2796), [#2692](https://github.com/adap/flower/pull/2692), [#2657](https://github.com/adap/flower/pull/2657), [#2813](https://github.com/adap/flower/pull/2813), [#2661](https://github.com/adap/flower/pull/2661), [#2398](https://github.com/adap/flower/pull/2398)) + + The Flower testing and development infrastructure has received substantial updates. This makes Flower 1.7 the most tested release ever. + +- **Update dependencies** ([#2753](https://github.com/adap/flower/pull/2753), [#2651](https://github.com/adap/flower/pull/2651), [#2739](https://github.com/adap/flower/pull/2739), [#2837](https://github.com/adap/flower/pull/2837), [#2788](https://github.com/adap/flower/pull/2788), [#2811](https://github.com/adap/flower/pull/2811), [#2774](https://github.com/adap/flower/pull/2774), [#2790](https://github.com/adap/flower/pull/2790), [#2751](https://github.com/adap/flower/pull/2751), [#2850](https://github.com/adap/flower/pull/2850), [#2812](https://github.com/adap/flower/pull/2812), [#2872](https://github.com/adap/flower/pull/2872), [#2736](https://github.com/adap/flower/pull/2736), [#2756](https://github.com/adap/flower/pull/2756), [#2857](https://github.com/adap/flower/pull/2857), [#2757](https://github.com/adap/flower/pull/2757), [#2810](https://github.com/adap/flower/pull/2810), [#2740](https://github.com/adap/flower/pull/2740), [#2789](https://github.com/adap/flower/pull/2789)) + +- **General improvements** ([#2803](https://github.com/adap/flower/pull/2803), [#2847](https://github.com/adap/flower/pull/2847), [#2877](https://github.com/adap/flower/pull/2877), [#2690](https://github.com/adap/flower/pull/2690), [#2889](https://github.com/adap/flower/pull/2889), [#2874](https://github.com/adap/flower/pull/2874), [#2819](https://github.com/adap/flower/pull/2819), 
[#2689](https://github.com/adap/flower/pull/2689), [#2457](https://github.com/adap/flower/pull/2457), [#2870](https://github.com/adap/flower/pull/2870), [#2669](https://github.com/adap/flower/pull/2669), [#2876](https://github.com/adap/flower/pull/2876), [#2885](https://github.com/adap/flower/pull/2885), [#2858](https://github.com/adap/flower/pull/2858), [#2867](https://github.com/adap/flower/pull/2867), [#2351](https://github.com/adap/flower/pull/2351), [#2886](https://github.com/adap/flower/pull/2886), [#2860](https://github.com/adap/flower/pull/2860), [#2828](https://github.com/adap/flower/pull/2828), [#2869](https://github.com/adap/flower/pull/2869), [#2875](https://github.com/adap/flower/pull/2875), [#2733](https://github.com/adap/flower/pull/2733), [#2488](https://github.com/adap/flower/pull/2488), [#2646](https://github.com/adap/flower/pull/2646), [#2879](https://github.com/adap/flower/pull/2879), [#2821](https://github.com/adap/flower/pull/2821), [#2855](https://github.com/adap/flower/pull/2855), [#2800](https://github.com/adap/flower/pull/2800), [#2807](https://github.com/adap/flower/pull/2807), [#2801](https://github.com/adap/flower/pull/2801), [#2804](https://github.com/adap/flower/pull/2804), [#2851](https://github.com/adap/flower/pull/2851), [#2787](https://github.com/adap/flower/pull/2787), [#2852](https://github.com/adap/flower/pull/2852), [#2672](https://github.com/adap/flower/pull/2672), [#2759](https://github.com/adap/flower/pull/2759)) + +### Incompatible changes + +- **Deprecate** `start_numpy_client` ([#2563](https://github.com/adap/flower/pull/2563), [#2718](https://github.com/adap/flower/pull/2718)) + + Until now, clients of type `NumPyClient` needed to be started via `start_numpy_client`. In our efforts to consolidate framework APIs, we have introduced changes, and now all client types should start via `start_client`. 
To continue using `NumPyClient` clients, you simply need to first call the `.to_client()` method and then pass the returned `Client` object to `start_client`. The examples and the documentation have been updated accordingly. + +- **Deprecate legacy DP wrappers** ([#2749](https://github.com/adap/flower/pull/2749)) + + Legacy DP wrapper classes are deprecated, but still functional. This is in preparation for an all-new pluggable version of differential privacy support in Flower. + +- **Make optional arg** `--callable` **in** `flower-client` **a required positional arg** ([#2673](https://github.com/adap/flower/pull/2673)) + +- **Rename** `certificates` **to** `root_certificates` **in** `Driver` ([#2890](https://github.com/adap/flower/pull/2890)) + +- **Drop experimental** `Task` **fields** ([#2866](https://github.com/adap/flower/pull/2866), [#2865](https://github.com/adap/flower/pull/2865)) + + Experimental fields `sa`, `legacy_server_message` and `legacy_client_message` were removed from `Task` message. The removed fields are superseded by the new `RecordSet` abstraction. + +- **Retire MXNet examples** ([#2724](https://github.com/adap/flower/pull/2724)) + + The development of the MXNet framework has ended and the project is now [archived on GitHub](https://github.com/apache/mxnet). Existing MXNet examples won't receive updates. 
## v1.6.0 (2023-11-28) @@ -86,7 +177,7 @@ We would like to give our special thanks to all the contributors who made the ne - FedBN ([#2608](https://github.com/adap/flower/pull/2608), [#2615](https://github.com/adap/flower/pull/2615)) -- **General updates to Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526), [#2302](https://github.com/adap/flower/pull/2302), [#2545](https://github.com/adap/flower/pull/2545)) +- **General updates to Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384), [#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526), [#2302](https://github.com/adap/flower/pull/2302), [#2545](https://github.com/adap/flower/pull/2545)) - **General updates to Flower Baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435), [#2462](https://github.com/adap/flower/pull/2462), [#2463](https://github.com/adap/flower/pull/2463), [#2461](https://github.com/adap/flower/pull/2461), [#2469](https://github.com/adap/flower/pull/2469), [#2466](https://github.com/adap/flower/pull/2466), [#2471](https://github.com/adap/flower/pull/2471), [#2472](https://github.com/adap/flower/pull/2472), [#2470](https://github.com/adap/flower/pull/2470)) @@ -94,7 +185,7 @@ We would like to give our special thanks to all the contributors who made the ne - **General updates to Flower SDKs** ([#2288](https://github.com/adap/flower/pull/2288), [#2429](https://github.com/adap/flower/pull/2429), [#2555](https://github.com/adap/flower/pull/2555), [#2543](https://github.com/adap/flower/pull/2543), [#2544](https://github.com/adap/flower/pull/2544), [#2597](https://github.com/adap/flower/pull/2597), 
[#2623](https://github.com/adap/flower/pull/2623)) -- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317), [#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2561](https://github.com/adap/flower/pull/2561), [#2273](https://github.com/adap/flower/pull/2273), [#2267](https://github.com/adap/flower/pull/2267), [#2274](https://github.com/adap/flower/pull/2274), [#2275](https://github.com/adap/flower/pull/2275), [#2432](https://github.com/adap/flower/pull/2432), [#2251](https://github.com/adap/flower/pull/2251), [#2321](https://github.com/adap/flower/pull/2321), [#1936](https://github.com/adap/flower/pull/1936), [#2408](https://github.com/adap/flower/pull/2408), [#2413](https://github.com/adap/flower/pull/2413), [#2401](https://github.com/adap/flower/pull/2401), [#2531](https://github.com/adap/flower/pull/2531), [#2534](https://github.com/adap/flower/pull/2534), [#2535](https://github.com/adap/flower/pull/2535), [#2521](https://github.com/adap/flower/pull/2521), [#2553](https://github.com/adap/flower/pull/2553), [#2596](https://github.com/adap/flower/pull/2596)) +- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [#2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [#2317](https://github.com/adap/flower/pull/2317), [#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2561](https://github.com/adap/flower/pull/2561), 
[#2273](https://github.com/adap/flower/pull/2273), [#2267](https://github.com/adap/flower/pull/2267), [#2274](https://github.com/adap/flower/pull/2274), [#2275](https://github.com/adap/flower/pull/2275), [#2432](https://github.com/adap/flower/pull/2432), [#2251](https://github.com/adap/flower/pull/2251), [#2321](https://github.com/adap/flower/pull/2321), [#1936](https://github.com/adap/flower/pull/1936), [#2408](https://github.com/adap/flower/pull/2408), [#2413](https://github.com/adap/flower/pull/2413), [#2401](https://github.com/adap/flower/pull/2401), [#2531](https://github.com/adap/flower/pull/2531), [#2534](https://github.com/adap/flower/pull/2534), [#2535](https://github.com/adap/flower/pull/2535), [#2521](https://github.com/adap/flower/pull/2521), [#2553](https://github.com/adap/flower/pull/2553), [#2596](https://github.com/adap/flower/pull/2596)) Flower received many improvements under the hood, too many to list here. @@ -122,11 +213,11 @@ We would like to give our special thanks to all the contributors who made the ne The new simulation engine has been rewritten from the ground up, yet it remains fully backwards compatible. It offers much improved stability and memory handling, especially when working with GPUs. Simulations transparently adapt to different settings to scale simulation in CPU-only, CPU+GPU, multi-GPU, or multi-node multi-GPU environments. - Comprehensive documentation includes a new [how-to run simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.dev/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.dev/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). 
+ Comprehensive documentation includes a new [how-to run simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.ai/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.ai/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). - **Restructure Flower Docs** ([#1824](https://github.com/adap/flower/pull/1824), [#1865](https://github.com/adap/flower/pull/1865), [#1884](https://github.com/adap/flower/pull/1884), [#1887](https://github.com/adap/flower/pull/1887), [#1919](https://github.com/adap/flower/pull/1919), [#1922](https://github.com/adap/flower/pull/1922), [#1920](https://github.com/adap/flower/pull/1920), [#1923](https://github.com/adap/flower/pull/1923), [#1924](https://github.com/adap/flower/pull/1924), [#1962](https://github.com/adap/flower/pull/1962), [#2006](https://github.com/adap/flower/pull/2006), [#2133](https://github.com/adap/flower/pull/2133), [#2203](https://github.com/adap/flower/pull/2203), [#2215](https://github.com/adap/flower/pull/2215), [#2122](https://github.com/adap/flower/pull/2122), [#2223](https://github.com/adap/flower/pull/2223), [#2219](https://github.com/adap/flower/pull/2219), [#2232](https://github.com/adap/flower/pull/2232), [#2233](https://github.com/adap/flower/pull/2233), [#2234](https://github.com/adap/flower/pull/2234), [#2235](https://github.com/adap/flower/pull/2235), [#2237](https://github.com/adap/flower/pull/2237), [#2238](https://github.com/adap/flower/pull/2238), [#2242](https://github.com/adap/flower/pull/2242), [#2231](https://github.com/adap/flower/pull/2231), [#2243](https://github.com/adap/flower/pull/2243), [#2227](https://github.com/adap/flower/pull/2227)) - Much effort went into a completely restructured Flower docs experience. 
The documentation on [flower.dev/docs](flower.dev/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. + Much effort went into a completely restructured Flower docs experience. The documentation on [flower.ai/docs](https://flower.ai/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. - **Introduce Flower Swift SDK** ([#1858](https://github.com/adap/flower/pull/1858), [#1897](https://github.com/adap/flower/pull/1897)) @@ -204,7 +295,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce support for XGBoost (**`FedXgbNnAvg` **strategy and example)** ([#1694](https://github.com/adap/flower/pull/1694), [#1709](https://github.com/adap/flower/pull/1709), [#1715](https://github.com/adap/flower/pull/1715), [#1717](https://github.com/adap/flower/pull/1717), [#1763](https://github.com/adap/flower/pull/1763), [#1795](https://github.com/adap/flower/pull/1795)) - XGBoost is a tree-based ensemble machine learning algorithm that uses gradient boosting to improve model accuracy. We added a new `FedXgbNnAvg` [strategy](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy/fedxgb_nn_avg.py), and a [code example](https://github.com/adap/flower/tree/main/examples/quickstart_xgboost_horizontal) that demonstrates the usage of this new strategy in an XGBoost project. + XGBoost is a tree-based ensemble machine learning algorithm that uses gradient boosting to improve model accuracy. We added a new `FedXgbNnAvg` [strategy](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy/fedxgb_nn_avg.py), and a [code example](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) that demonstrates the usage of this new strategy in an XGBoost project. 
- **Introduce iOS SDK (preview)** ([#1621](https://github.com/adap/flower/pull/1621), [#1764](https://github.com/adap/flower/pull/1764)) @@ -212,11 +303,11 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new "What is Federated Learning?" tutorial** ([#1657](https://github.com/adap/flower/pull/1657), [#1721](https://github.com/adap/flower/pull/1721)) - A new [entry-level tutorial](https://flower.dev/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Fedetated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! + A new [entry-level tutorial](https://flower.ai/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Federated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! - **Introduce new Flower Baseline: FedProx MNIST** ([#1513](https://github.com/adap/flower/pull/1513), [#1680](https://github.com/adap/flower/pull/1680), [#1681](https://github.com/adap/flower/pull/1681), [#1679](https://github.com/adap/flower/pull/1679)) - This new baseline replicates the MNIST+CNN task from the paper [Federated Optimization in Heterogeneous Networks (Li et al., 2018)](https://arxiv.org/abs/1812.06127). It uses the `FedProx` strategy, which aims at making convergence more robust in heterogenous settings. + This new baseline replicates the MNIST+CNN task from the paper [Federated Optimization in Heterogeneous Networks (Li et al., 2018)](https://arxiv.org/abs/1812.06127). It uses the `FedProx` strategy, which aims at making convergence more robust in heterogeneous settings. 
- **Introduce new Flower Baseline: FedAvg FEMNIST** ([#1655](https://github.com/adap/flower/pull/1655)) @@ -238,7 +329,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new example using** `TabNet` **and Flower** ([#1725](https://github.com/adap/flower/pull/1725)) - TabNet is a powerful and flexible framework for training machine learning models on tabular data. We now have a federated example using Flower: [https://github.com/adap/flower/tree/main/examples/tabnet](https://github.com/adap/flower/tree/main/examples/quickstart_tabnet). + TabNet is a powerful and flexible framework for training machine learning models on tabular data. We now have a federated example using Flower: [quickstart-tabnet](https://github.com/adap/flower/tree/main/examples/quickstart-tabnet). - **Add new how-to guide for monitoring simulations** ([#1649](https://github.com/adap/flower/pull/1649)) @@ -280,7 +371,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new example of Federated Learning using fastai and Flower** ([#1598](https://github.com/adap/flower/pull/1598)) - A new code example (`quickstart_fastai`) demonstrates federated learning with [fastai](https://www.fast.ai/) and Flower. You can find it here: [quickstart_fastai](https://github.com/adap/flower/tree/main/examples/quickstart_fastai). + A new code example (`quickstart-fastai`) demonstrates federated learning with [fastai](https://www.fast.ai/) and Flower. You can find it here: [quickstart-fastai](https://github.com/adap/flower/tree/main/examples/quickstart-fastai). 
- **Make Android example compatible with** `flwr >= 1.0.0` **and the latest versions of Android** ([#1603](https://github.com/adap/flower/pull/1603)) @@ -326,7 +417,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new Flower Baseline: FedAvg MNIST** ([#1497](https://github.com/adap/flower/pull/1497), [#1552](https://github.com/adap/flower/pull/1552)) - Over the coming weeks, we will be releasing a number of new reference implementations useful especially to FL newcomers. They will typically revisit well known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.dev/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) + Over the coming weeks, we will be releasing a number of new reference implementations useful especially to FL newcomers. They will typically revisit well known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.ai/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) - **Improve GPU support in simulations** ([#1555](https://github.com/adap/flower/pull/1555)) @@ -336,16 +427,16 @@ We would like to give our special thanks to all the contributors who made the ne Some users reported that Jupyter Notebooks have not always been easy to use on GPU instances. We listened and made improvements to all of our Jupyter notebooks! 
Check out the updated notebooks here: - - [An Introduction to Federated Learning](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) - - [Strategies in Federated Learning](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) - - [Building a Strategy](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) - - [Client and NumPyClient](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) + - [An Introduction to Federated Learning](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) + - [Strategies in Federated Learning](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) + - [Building a Strategy](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) + - [Client and NumPyClient](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) - **Introduce optional telemetry** ([#1533](https://github.com/adap/flower/pull/1533), [#1544](https://github.com/adap/flower/pull/1544), [#1584](https://github.com/adap/flower/pull/1584)) After a [request for feedback](https://github.com/adap/flower/issues/1534) from the community, the Flower open-source project introduces optional collection of *anonymous* usage metrics to make well-informed decisions to improve Flower. Doing this enables the Flower team to understand how Flower is used and what challenges users might face. - **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. [Read more.](https://flower.dev/docs/telemetry.html). + **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. 
[Read more.](https://flower.ai/docs/telemetry.html). - **Introduce (experimental) Driver API** ([#1520](https://github.com/adap/flower/pull/1520), [#1525](https://github.com/adap/flower/pull/1525), [#1545](https://github.com/adap/flower/pull/1545), [#1546](https://github.com/adap/flower/pull/1546), [#1550](https://github.com/adap/flower/pull/1550), [#1551](https://github.com/adap/flower/pull/1551), [#1567](https://github.com/adap/flower/pull/1567)) @@ -359,7 +450,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new Federated Analytics with Pandas example** ([#1469](https://github.com/adap/flower/pull/1469), [#1535](https://github.com/adap/flower/pull/1535)) - A new code example (`quickstart_pandas`) demonstrates federated analytics with Pandas and Flower. You can find it here: [quickstart_pandas](https://github.com/adap/flower/tree/main/examples/quickstart_pandas). + A new code example (`quickstart-pandas`) demonstrates federated analytics with Pandas and Flower. You can find it here: [quickstart-pandas](https://github.com/adap/flower/tree/main/examples/quickstart-pandas). - **Add new strategies: Krum and MultiKrum** ([#1481](https://github.com/adap/flower/pull/1481)) @@ -377,7 +468,7 @@ We would like to give our special thanks to all the contributors who made the ne As usual, the documentation has improved quite a bit. It is another step in our effort to make the Flower documentation the best documentation of any project. Stay tuned and as always, feel free to provide feedback! - One highlight is the new [first time contributor guide](https://flower.dev/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! + One highlight is the new [first time contributor guide](https://flower.ai/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! 
### Incompatible changes @@ -458,7 +549,7 @@ None We would like to give our **special thanks** to all the contributors who made Flower 1.0 possible (in reverse [GitHub Contributors](https://github.com/adap/flower/graphs/contributors) order): -[@rtaiello](https://github.com/rtaiello), [@g-pichler](https://github.com/g-pichler), [@rob-luke](https://github.com/rob-luke), [@andreea-zaharia](https://github.com/andreea-zaharia), [@kinshukdua](https://github.com/kinshukdua), [@nfnt](https://github.com/nfnt), [@tatiana-s](https://github.com/tatiana-s), [@TParcollet](https://github.com/TParcollet), [@vballoli](https://github.com/vballoli), [@negedng](https://github.com/negedng), [@RISHIKESHAVAN](https://github.com/RISHIKESHAVAN), [@hei411](https://github.com/hei411), [@SebastianSpeitel](https://github.com/SebastianSpeitel), [@AmitChaulwar](https://github.com/AmitChaulwar), [@Rubiel1](https://github.com/Rubiel1), [@FANTOME-PAN](https://github.com/FANTOME-PAN), [@Rono-BC](https://github.com/Rono-BC), [@lbhm](https://github.com/lbhm), [@sishtiaq](https://github.com/sishtiaq), [@remde](https://github.com/remde), [@Jueun-Park](https://github.com/Jueun-Park), [@architjen](https://github.com/architjen), [@PratikGarai](https://github.com/PratikGarai), [@mrinaald](https://github.com/mrinaald), [@zliel](https://github.com/zliel), [@MeiruiJiang](https://github.com/MeiruiJiang), [@sandracl72](https://github.com/sandracl72), [@gubertoli](https://github.com/gubertoli), [@Vingt100](https://github.com/Vingt100), [@MakGulati](https://github.com/MakGulati), [@cozek](https://github.com/cozek), [@jafermarq](https://github.com/jafermarq), [@sisco0](https://github.com/sisco0), [@akhilmathurs](https://github.com/akhilmathurs), [@CanTuerk](https://github.com/CanTuerk), [@mariaboerner1987](https://github.com/mariaboerner1987), [@pedropgusmao](https://github.com/pedropgusmao), [@tanertopal](https://github.com/tanertopal), [@danieljanes](https://github.com/danieljanes). 
+[@rtaiello](https://github.com/rtaiello), [@g-pichler](https://github.com/g-pichler), [@rob-luke](https://github.com/rob-luke), [@andreea-zaharia](https://github.com/andreea-zaharia), [@kinshukdua](https://github.com/kinshukdua), [@nfnt](https://github.com/nfnt), [@tatiana-s](https://github.com/tatiana-s), [@TParcollet](https://github.com/TParcollet), [@vballoli](https://github.com/vballoli), [@negedng](https://github.com/negedng), [@RISHIKESHAVAN](https://github.com/RISHIKESHAVAN), [@hei411](https://github.com/hei411), [@SebastianSpeitel](https://github.com/SebastianSpeitel), [@AmitChaulwar](https://github.com/AmitChaulwar), [@Rubiel1](https://github.com/Rubiel1), [@FANTOME-PAN](https://github.com/FANTOME-PAN), [@Rono-BC](https://github.com/Rono-BC), [@lbhm](https://github.com/lbhm), [@sishtiaq](https://github.com/sishtiaq), [@remde](https://github.com/remde), [@Jueun-Park](https://github.com/Jueun-Park), [@architjen](https://github.com/architjen), [@PratikGarai](https://github.com/PratikGarai), [@mrinaald](https://github.com/mrinaald), [@zliel](https://github.com/zliel), [@MeiruiJiang](https://github.com/MeiruiJiang), [@sancarlim](https://github.com/sancarlim), [@gubertoli](https://github.com/gubertoli), [@Vingt100](https://github.com/Vingt100), [@MakGulati](https://github.com/MakGulati), [@cozek](https://github.com/cozek), [@jafermarq](https://github.com/jafermarq), [@sisco0](https://github.com/sisco0), [@akhilmathurs](https://github.com/akhilmathurs), [@CanTuerk](https://github.com/CanTuerk), [@mariaboerner1987](https://github.com/mariaboerner1987), [@pedropgusmao](https://github.com/pedropgusmao), [@tanertopal](https://github.com/tanertopal), [@danieljanes](https://github.com/danieljanes). 
### Incompatible changes @@ -566,7 +657,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Flower Baselines (preview): FedOpt, FedBN, FedAvgM** ([#919](https://github.com/adap/flower/pull/919), [#1127](https://github.com/adap/flower/pull/1127), [#914](https://github.com/adap/flower/pull/914)) - The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.dev/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.dev/docs/contributing-baselines.html). + The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/baselines/how-to-contribute-baselines.html). 
- **C++ client SDK (preview) and code example** ([#1111](https://github.com/adap/flower/pull/1111)) @@ -612,7 +703,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - New option to keep Ray running if Ray was already initialized in `start_simulation` ([#1177](https://github.com/adap/flower/pull/1177)) - Add support for custom `ClientManager` as a `start_simulation` parameter ([#1171](https://github.com/adap/flower/pull/1171)) - - New documentation for [implementing strategies](https://flower.dev/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) + - New documentation for [implementing strategies](https://flower.ai/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) - New mobile-friendly documentation theme ([#1174](https://github.com/adap/flower/pull/1174)) - Limit version range for (optional) `ray` dependency to include only compatible releases (`>=1.9.2,<1.12.0`) ([#1205](https://github.com/adap/flower/pull/1205)) @@ -732,7 +823,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Renamed q-FedAvg strategy** ([#802](https://github.com/adap/flower/pull/802)) - The strategy named `QffedAvg` was renamed to `QFedAvg` to better reflect the notation given in the original paper (q-FFL is the optimization objective, q-FedAvg is the proposed solver). Note the the original (now deprecated) `QffedAvg` class is still available for compatibility reasons (it will be removed in a future release). + The strategy named `QffedAvg` was renamed to `QFedAvg` to better reflect the notation given in the original paper (q-FFL is the optimization objective, q-FedAvg is the proposed solver). 
Note the original (now deprecated) `QffedAvg` class is still available for compatibility reasons (it will be removed in a future release). - **Deprecated and renamed code example** `simulation_pytorch` **to** `simulation_pytorch_legacy` ([#791](https://github.com/adap/flower/pull/791)) @@ -751,9 +842,9 @@ We would like to give our **special thanks** to all the contributors who made Fl The Flower server is now fully task-agnostic, all remaining instances of task-specific metrics (such as `accuracy`) have been replaced by custom metrics dictionaries. Flower 0.15 introduced the capability to pass a dictionary containing custom metrics from client to server. As of this release, custom metrics replace task-specific metrics on the server. - Custom metric dictionaries are now used in two user-facing APIs: they are returned from Strategy methods `aggregate_fit`/`aggregate_evaluate` and they enable evaluation functions passed to build-in strategies (via `eval_fn`) to return more than two evaluation metrics. Strategies can even return *aggregated* metrics dictionaries for the server to keep track of. + Custom metric dictionaries are now used in two user-facing APIs: they are returned from Strategy methods `aggregate_fit`/`aggregate_evaluate` and they enable evaluation functions passed to built-in strategies (via `eval_fn`) to return more than two evaluation metrics. Strategies can even return *aggregated* metrics dictionaries for the server to keep track of. - Stratey implementations should migrate their `aggregate_fit` and `aggregate_evaluate` methods to the new return type (e.g., by simply returning an empty `{}`), server-side evaluation functions should migrate from `return loss, accuracy` to `return loss, {"accuracy": accuracy}`. 
+ Strategy implementations should migrate their `aggregate_fit` and `aggregate_evaluate` methods to the new return type (e.g., by simply returning an empty `{}`), server-side evaluation functions should migrate from `return loss, accuracy` to `return loss, {"accuracy": accuracy}`. Flower 0.15-style return types are deprecated (but still supported), compatibility will be removed in a future release. @@ -773,7 +864,7 @@ We would like to give our **special thanks** to all the contributors who made Fl The Flower server is now fully serialization-agnostic. Prior usage of class `Weights` (which represents parameters as deserialized NumPy ndarrays) was replaced by class `Parameters` (e.g., in `Strategy`). `Parameters` objects are fully serialization-agnostic and represents parameters as byte arrays, the `tensor_type` attributes indicates how these byte arrays should be interpreted (e.g., for serialization/deserialization). - Built-in strategies implement this approach by handling serialization and deserialization to/from `Weights` internally. Custom/3rd-party Strategy implementations should update to the slighly changed Strategy method definitions. Strategy authors can consult PR [#721](https://github.com/adap/flower/pull/721) to see how strategies can easily migrate to the new format. + Built-in strategies implement this approach by handling serialization and deserialization to/from `Weights` internally. Custom/3rd-party Strategy implementations should update to the slightly changed Strategy method definitions. Strategy authors can consult PR [#721](https://github.com/adap/flower/pull/721) to see how strategies can easily migrate to the new format. - Deprecated `flwr.server.Server.evaluate`, use `flwr.server.Server.evaluate_round` instead ([#717](https://github.com/adap/flower/pull/717)) @@ -794,7 +885,7 @@ What's new? 
) model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) - # Create strategy and initilize parameters on the server-side + # Create strategy and initialize parameters on the server-side strategy = fl.server.strategy.FedAvg( # ... (other constructor arguments) initial_parameters=model.get_weights(), diff --git a/doc/source/ref-example-projects.rst b/doc/source/ref-example-projects.rst index b47bd8e48997..bade86dfaa54 100644 --- a/doc/source/ref-example-projects.rst +++ b/doc/source/ref-example-projects.rst @@ -23,8 +23,8 @@ The TensorFlow/Keras quickstart example shows CIFAR-10 image classification with MobileNetV2: - `Quickstart TensorFlow (Code) `_ -- `Quickstart TensorFlow (Tutorial) `_ -- `Quickstart TensorFlow (Blog Post) `_ +- :doc:`Quickstart TensorFlow (Tutorial) ` +- `Quickstart TensorFlow (Blog Post) `_ Quickstart PyTorch @@ -34,7 +34,7 @@ The PyTorch quickstart example shows CIFAR-10 image classification with a simple Convolutional Neural Network: - `Quickstart PyTorch (Code) `_ -- `Quickstart PyTorch (Tutorial) `_ +- :doc:`Quickstart PyTorch (Tutorial) ` PyTorch: From Centralized To Federated @@ -43,7 +43,7 @@ PyTorch: From Centralized To Federated This example shows how a regular PyTorch project can be federated using Flower: - `PyTorch: From Centralized To Federated (Code) `_ -- `PyTorch: From Centralized To Federated (Tutorial) `_ +- :doc:`PyTorch: From Centralized To Federated (Tutorial) ` Federated Learning on Raspberry Pi and Nvidia Jetson @@ -52,7 +52,7 @@ Federated Learning on Raspberry Pi and Nvidia Jetson This example shows how Flower can be used to build a federated learning system that run across Raspberry Pi and Nvidia Jetson: - `Federated Learning on Raspberry Pi and Nvidia Jetson (Code) `_ -- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ +- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ @@ -60,7 +60,7 @@ Legacy Examples (`flwr_example`) -------------------------------- .. 
warning:: - The useage examples in `flwr_example` are deprecated and will be removed in + The usage examples in `flwr_example` are deprecated and will be removed in the future. New examples are provided as standalone projects in `examples `_. @@ -114,7 +114,7 @@ For more details, see :code:`src/py/flwr_example/pytorch_cifar`. ImageNet-2012 Image Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -`ImageNet-2012 `_ is one of the major computer +`ImageNet-2012 `_ is one of the major computer vision datasets. The Flower ImageNet example uses PyTorch to train a ResNet-18 classifier in a federated learning setup with ten clients. diff --git a/doc/source/ref-faq.rst b/doc/source/ref-faq.rst index 13c44bc64b0e..26b7dca4a0a7 100644 --- a/doc/source/ref-faq.rst +++ b/doc/source/ref-faq.rst @@ -3,23 +3,23 @@ FAQ This page collects answers to commonly asked questions about Federated Learning with Flower. -.. dropdown:: :fa:`eye,mr-1` Can Flower run on Juptyter Notebooks / Google Colab? +.. dropdown:: :fa:`eye,mr-1` Can Flower run on Jupyter Notebooks / Google Colab? Yes, it can! Flower even comes with a few under-the-hood optimizations to make it work even better on Colab. Here's a quickstart example: - + * `Flower simulation PyTorch `_ * `Flower simulation TensorFlow/Keras `_ .. dropdown:: :fa:`eye,mr-1` How can I run Federated Learning on a Raspberry Pi? - Find the `blog post about federated learning on embedded device here `_ and the corresponding `GitHub code example `_. + Find the `blog post about federated learning on embedded device here `_ and the corresponding `GitHub code example `_. .. dropdown:: :fa:`eye,mr-1` Does Flower support federated learning on Android devices? - Yes, it does. Please take a look at our `blog post `_ or check out the code examples: + Yes, it does. Please take a look at our `blog post `_ or check out the code examples: - * `Android Kotlin example `_ - * `Android Java example `_ + * `Android Kotlin example `_ + * `Android Java example `_ .. 
dropdown:: :fa:`eye,mr-1` Can I combine federated learning with blockchain? @@ -27,6 +27,6 @@ This page collects answers to commonly asked questions about Federated Learning * `Flower meets Nevermined GitHub Repository `_. * `Flower meets Nevermined YouTube video `_. - * `Flower meets KOSMoS `_. + * `Flower meets KOSMoS `_. * `Flower meets Talan blog post `_ . * `Flower meets Talan GitHub Repository `_ . diff --git a/doc/source/ref-telemetry.md b/doc/source/ref-telemetry.md index 206e641d8b41..49efef5c8559 100644 --- a/doc/source/ref-telemetry.md +++ b/doc/source/ref-telemetry.md @@ -41,7 +41,7 @@ Flower telemetry collects the following metrics: **Source.** Flower telemetry tries to store a random source ID in `~/.flwr/source` the first time a telemetry event is generated. The source ID is important to identify whether an issue is recurring or whether an issue is triggered by multiple clusters running concurrently (which often happens in simulation). For example, if a device runs multiple workloads at the same time, and this results in an issue, then, in order to reproduce the issue, multiple workloads must be started at the same time. -You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.dev`. All events related to that source ID will then be permanently deleted. +You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.ai`. All events related to that source ID will then be permanently deleted. We will not collect any personally identifiable information. If you think any of the metrics collected could be misused in any way, please [get in touch with us](#how-to-contact-us). We will update this page to reflect any changes to the metrics collected and publish changes in the changelog. 
@@ -63,4 +63,4 @@ FLWR_TELEMETRY_ENABLED=0 FLWR_TELEMETRY_LOGGING=1 python server.py # or client.p ## How to contact us -We want to hear from you. If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.dev/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.dev`). +We want to hear from you. If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.ai/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.ai`). diff --git a/doc/source/tutorial-quickstart-huggingface.rst b/doc/source/tutorial-quickstart-huggingface.rst index 7718e6558456..7d8128230901 100644 --- a/doc/source/tutorial-quickstart-huggingface.rst +++ b/doc/source/tutorial-quickstart-huggingface.rst @@ -9,8 +9,8 @@ Quickstart 🤗 Transformers Let's build a federated learning system using Hugging Face Transformers and Flower! -We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. -More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) +We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. +More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) for sequence classification over a dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative. @@ -32,8 +32,8 @@ Standard Hugging Face workflow Handling the data ^^^^^^^^^^^^^^^^^ -To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. -We then need to tokenize the data and create :code:`PyTorch` dataloaders, +To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. +We then need to tokenize the data and create :code:`PyTorch` dataloaders, this is all done in the :code:`load_data` function: .. 
code-block:: python @@ -80,8 +80,8 @@ this is all done in the :code:`load_data` function: Training and testing the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once we have a way of creating our trainloader and testloader, -we can take care of the training and testing. +Once we have a way of creating our trainloader and testloader, +we can take care of the training and testing. This is very similar to any :code:`PyTorch` training or testing loop: .. code-block:: python @@ -120,12 +120,12 @@ This is very similar to any :code:`PyTorch` training or testing loop: Creating the model itself ^^^^^^^^^^^^^^^^^^^^^^^^^ -To create the model itself, +To create the model itself, we will just load the pre-trained distillBERT model using Hugging Face’s :code:`AutoModelForSequenceClassification` : .. code-block:: python - from transformers import AutoModelForSequenceClassification + from transformers import AutoModelForSequenceClassification net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 @@ -138,8 +138,8 @@ Federating the example Creating the IMDBClient ^^^^^^^^^^^^^^^^^^^^^^^ -To federate our example to multiple clients, -we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). +To federate our example to multiple clients, +we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). This is very easy, as our model is a standard :code:`PyTorch` model: .. code-block:: python @@ -166,17 +166,17 @@ This is very easy, as our model is a standard :code:`PyTorch` model: return float(loss), len(testloader), {"accuracy": float(accuracy)} -The :code:`get_parameters` function lets the server get the client's parameters. -Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. 
-Finally, the :code:`fit` function trains the model locally for the client, -and the :code:`evaluate` function tests the model locally and returns the relevant metrics. +The :code:`get_parameters` function lets the server get the client's parameters. +Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. +Finally, the :code:`fit` function trains the model locally for the client, +and the :code:`evaluate` function tests the model locally and returns the relevant metrics. Starting the server ^^^^^^^^^^^^^^^^^^^ -Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results. -Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`, -which will define the global weights as the average of all the clients' weights at each round) +Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results. +Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`, +which will define the global weights as the average of all the clients' weights at each round) and then using the :code:`flwr.server.start_server` function: .. 
code-block:: python @@ -186,7 +186,7 @@ and then using the :code:`flwr.server.start_server` function: losses = [num_examples * m["loss"] for num_examples, m in metrics] examples = [num_examples for num_examples, _ in metrics] return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} - + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=1.0, @@ -202,7 +202,7 @@ and then using the :code:`flwr.server.start_server` function: ) -The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst +The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst the clients (basically this allows us to display a nice average accuracy and loss for every round). Putting everything together @@ -212,19 +212,18 @@ We can now start client instances using: .. code-block:: python - fl.client.start_numpy_client( - server_address="127.0.0.1:8080", - client=IMDBClient() + fl.client.start_client( + server_address="127.0.0.1:8080", + client=IMDBClient().to_client() ) And they will be able to connect to the server and start the federated training. -If you want to check out everything put together, -you should check out the full code example: -[https://github.com/adap/flower/tree/main/examples/quickstart-huggingface](https://github.com/adap/flower/tree/main/examples/quickstart-huggingface). +If you want to check out everything put together, +you should check out the `full code example `_ . -Of course, this is a very basic example, and a lot can be added or modified, +Of course, this is a very basic example, and a lot can be added or modified, it was just to showcase how simply we could federate a Hugging Face workflow using Flower. Note that in this example we used :code:`PyTorch`, but we could have very well used :code:`TensorFlow`. 
diff --git a/doc/source/tutorial-quickstart-ios.rst b/doc/source/tutorial-quickstart-ios.rst index 7c8007baaa75..e4315ce569fb 100644 --- a/doc/source/tutorial-quickstart-ios.rst +++ b/doc/source/tutorial-quickstart-ios.rst @@ -7,14 +7,14 @@ Quickstart iOS .. meta:: :description: Read this Federated Learning quickstart tutorial for creating an iOS app using Flower to train a neural network on MNIST. -In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. +In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. -First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. For the Flower client implementation in iOS, it is recommended to use Xcode as our IDE. -Our example consists of one Python *server* and two iPhone *clients* that all have the same model. +Our example consists of one Python *server* and two iPhone *clients* that all have the same model. -*Clients* are responsible for generating individual weight updates for the model based on their local datasets. +*Clients* are responsible for generating individual weight updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. 
@@ -44,10 +44,10 @@ For simplicity reasons we will use the complete Flower client with CoreML, that public func getParameters() -> GetParametersRes { let parameters = parameters.weightsToParameters() let status = Status(code: .ok, message: String()) - + return GetParametersRes(parameters: parameters, status: status) } - + /// Calls the routine to fit the local model /// /// - Returns: The result from the local training, e.g., updated parameters @@ -55,17 +55,17 @@ For simplicity reasons we will use the complete Flower client with CoreML, that let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .train) let parameters = parameters.weightsToParameters() - + return FitRes(parameters: parameters, numExamples: result.numSamples, status: status) } - + /// Calls the routine to evaluate the local model /// /// - Returns: The result from the evaluation, e.g., loss public func evaluate(ins: EvaluateIns) -> EvaluateRes { let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .test) - + return EvaluateRes(loss: Float(result.loss), numExamples: result.numSamples, status: status) } @@ -88,18 +88,18 @@ For the MNIST dataset, we need to preprocess it into :code:`MLBatchProvider` obj // prepare train dataset let trainBatchProvider = DataLoader.trainBatchProvider() { _ in } - + // prepare test dataset let testBatchProvider = DataLoader.testBatchProvider() { _ in } - + // load them together - let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, + let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, testBatchProvider: testBatchProvider) Since CoreML does not allow the model parameters to be seen before training, and accessing the model parameters during or after the training can only be done by specifying the layer name, -we need to know this informations 
beforehand, through looking at the model specification, which are written as proto files. The implementation can be seen in :code:`MLModelInspect`. +we need to know this information beforehand, through looking at the model specification, which are written as proto files. The implementation can be seen in :code:`MLModelInspect`. -After we have all of the necessary informations, let's create our Flower client. +After we have all of the necessary information, let's create our Flower client. .. code-block:: swift @@ -122,7 +122,7 @@ Then start the Flower gRPC client and start communicating to the server by passi self.flwrGRPC.startFlwrGRPC(client: self.mlFlwrClient) That's it for the client. We only have to implement :code:`Client` or call the provided -:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attribute :code:`hostname` and :code:`port` tells the client which server to connect to. +:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attributes :code:`hostname` and :code:`port` tell the client which server to connect to. This can be done by entering the hostname and port in the application before clicking the start button to start the federated learning process. Flower Server diff --git a/doc/source/tutorial-quickstart-jax.rst b/doc/source/tutorial-quickstart-jax.rst index 945f231e112e..d2b9243e2bb3 100644 --- a/doc/source/tutorial-quickstart-jax.rst +++ b/doc/source/tutorial-quickstart-jax.rst @@ -265,7 +265,7 @@ Having defined the federation process, we can run it.
# Start Flower client client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y) - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client) + fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) if __name__ == "__main__": main() diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst index 149d060e4c00..fe582f793280 100644 --- a/doc/source/tutorial-quickstart-mxnet.rst +++ b/doc/source/tutorial-quickstart-mxnet.rst @@ -4,16 +4,18 @@ Quickstart MXNet ================ +.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongside Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. + .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST. -In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet. +In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet. -It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. +*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. 
A complete cycle of parameters updates is called a *round*. @@ -33,12 +35,12 @@ Since we want to use MXNet, let's go ahead and install it: Flower Client ------------- -Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet´s `Hand-written Digit Recognition tutorial `_. +Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet´s `Hand-written Digit Recognition tutorial `_. In a file called :code:`client.py`, import Flower and MXNet related packages: .. code-block:: python - + import flwr as fl import numpy as np @@ -56,7 +58,7 @@ In addition, define the device allocation in MXNet with: DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()] -We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data. +We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data. .. code-block:: python @@ -70,7 +72,7 @@ We use MXNet to load MNIST, a popular image classification dataset of handwritte val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size) return train_data, val_data -Define the training and loss with MXNet. We train the model by looping over the dataset, measure the corresponding loss, and optimize it. +Define the training and loss with MXNet. We train the model by looping over the dataset, measure the corresponding loss, and optimize it. .. code-block:: python @@ -108,7 +110,7 @@ Define the training and loss with MXNet. 
We train the model by looping over the return trainings_metric, num_examples -Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy on the test set. +Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy on the test set. .. code-block:: python @@ -153,7 +155,7 @@ Our Flower clients will use a simple :code:`Sequential` model: init = nd.random.uniform(shape=(2, 784)) model(init) -After loading the dataset with :code:`load_data()` we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client. +After loading the dataset with :code:`load_data()` we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client. The Flower server interacts with clients through an interface called :code:`Client`. When the server selects a particular client for training, it @@ -205,7 +207,7 @@ They can be implemented in the following way: [accuracy, loss], num_examples = test(model, val_data) print("Evaluation accuracy & loss", accuracy, loss) return float(loss[1]), val_data.batch_size, {"accuracy": float(accuracy[1])} - + We can now create an instance of our class :code:`MNISTClient` and add one line to actually run this client: diff --git a/doc/source/tutorial-quickstart-pytorch.rst b/doc/source/tutorial-quickstart-pytorch.rst index fb77d107b63f..895590808a2b 100644 --- a/doc/source/tutorial-quickstart-pytorch.rst +++ b/doc/source/tutorial-quickstart-pytorch.rst @@ -10,13 +10,13 @@ Quickstart PyTorch .. youtube:: jOmmuzMIQ4c :width: 100% -In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch. +In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch. 
-First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. +*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -26,7 +26,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library: +Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library: .. code-block:: shell @@ -36,12 +36,12 @@ Since we want to use PyTorch to solve a computer vision task, let's go ahead and Flower Client ------------- -Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_. +Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_. In a file called :code:`client.py`, import Flower and PyTorch related packages: .. 
code-block:: python - + from collections import OrderedDict import torch @@ -59,7 +59,7 @@ In addition, we define the device allocation in PyTorch with: DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized. +We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized. .. code-block:: python @@ -75,7 +75,7 @@ We use PyTorch to load CIFAR10, a popular colored image classification dataset f num_examples = {"trainset" : len(trainset), "testset" : len(testset)} return trainloader, testloader, num_examples -Define the loss and optimizer with PyTorch. The training of the dataset is done by looping over the dataset, measure the corresponding loss and optimize it. +Define the loss and optimizer with PyTorch. The training of the dataset is done by looping over the dataset, measure the corresponding loss and optimize it. .. code-block:: python @@ -91,7 +91,7 @@ Define the loss and optimizer with PyTorch. The training of the dataset is done loss.backward() optimizer.step() -Define then the validation of the machine learning network. We loop over the test set and measure the loss and accuracy of the test set. +Define then the validation of the machine learning network. We loop over the test set and measure the loss and accuracy of the test set. .. code-block:: python @@ -139,7 +139,7 @@ The Flower clients will use a simple CNN adapted from 'PyTorch: A 60 Minute Blit net = Net().to(DEVICE) trainloader, testloader, num_examples = load_data() -After loading the data set with :code:`load_data()` we define the Flower interface. +After loading the data set with :code:`load_data()` we define the Flower interface. 
The Flower server interacts with clients through an interface called :code:`Client`. When the server selects a particular client for training, it @@ -191,10 +191,10 @@ to actually run this client: .. code-block:: python - fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient()) + fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"[::]:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we point the client at. diff --git a/doc/source/tutorial-quickstart-scikitlearn.rst b/doc/source/tutorial-quickstart-scikitlearn.rst index b33068e975fa..d1d47dc37f19 100644 --- a/doc/source/tutorial-quickstart-scikitlearn.rst +++ b/doc/source/tutorial-quickstart-scikitlearn.rst @@ -7,13 +7,13 @@ Quickstart scikit-learn .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with scikit-learn to train a linear regression model. -In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn. +In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn. 
-It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. +*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameters updates is called a *round*. @@ -23,7 +23,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use scikt-learn, let's go ahead and install it: +Since we want to use scikit-learn, let's go ahead and install it: .. code-block:: shell @@ -43,7 +43,7 @@ Now that we have all our dependencies installed, let's run a simple distributed However, before setting up the client and server, we will define all functionalities that we need for our federated learning setup within :code:`utils.py`. The :code:`utils.py` contains different functions defining all the machine learning basics: * :code:`get_model_parameters()` - * Returns the paramters of a :code:`sklearn` LogisticRegression model + * Returns the parameters of a :code:`sklearn` LogisticRegression model * :code:`set_model_params()` * Sets the parameters of a :code:`sklean` LogisticRegression model * :code:`set_initial_params()` @@ -59,7 +59,7 @@ Please check out :code:`utils.py` `here `_, a popular image classification dataset of handwritten digits for machine learning. 
The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. +We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. .. code-block:: python @@ -145,10 +145,10 @@ to actually run this client: .. code-block:: python - fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient()) + fl.client.start_client("0.0.0.0:8080", client=MnistClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"0.0.0.0:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"0.0.0.0:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"0.0.0.0:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we pass to the client. diff --git a/doc/source/tutorial-quickstart-tensorflow.rst b/doc/source/tutorial-quickstart-tensorflow.rst index 64b2255a9ac6..bd63eb461d21 100644 --- a/doc/source/tutorial-quickstart-tensorflow.rst +++ b/doc/source/tutorial-quickstart-tensorflow.rst @@ -84,11 +84,11 @@ to actually run this client: .. 
code-block:: python - fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient()) + fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client()) That's it for the client. We only have to implement :code:`Client` or -:code:`NumPyClient` and call :code:`fl.client.start_client()` or :code:`fl.client.start_numpy_client()`. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use +:code:`NumPyClient` and call :code:`fl.client.start_client()`. If you implement a client of type :code:`NumPyClient` you'll need to first call its :code:`to_client()` method. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"[::]:8080"`. If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the :code:`server_address` we point the client at. diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst index 7eb58da7f2f6..7ac055138814 100644 --- a/doc/source/tutorial-quickstart-xgboost.rst +++ b/doc/source/tutorial-quickstart-xgboost.rst @@ -36,7 +36,7 @@ and then we dive into a more complex example (`full code xgboost-comprehensive < Environment Setup -------------------- -First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. We first need to install Flower and Flower Datasets. You can do this by running : @@ -595,9 +595,164 @@ Comprehensive Federated XGBoost Now that you have known how federated XGBoost work with Flower, it's time to run some more comprehensive experiments by customising the experimental settings. 
In the xgboost-comprehensive example (`full code `_), -we provide more options to define various experimental setups, including data partitioning and centralised/distributed evaluation. +we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation. +We also support :doc:`Flower simulation ` making it easy to simulate large client cohorts in a resource-aware manner. Let's take a look! +Cyclic training +~~~~~~~~~~~~~~~~~~ + +In addition to bagging aggregation, we offer a cyclic training scheme, which performs FL in a client-by-client fashion. +Instead of aggregating multiple clients, there is only one single client participating in the training per round in the cyclic training scenario. +The trained local XGBoost trees will be passed to the next client as an initialised model for next round's boosting. + +To do this, we first customise a :code:`ClientManager` in :code:`server_utils.py`: + +.. code-block:: python + + class CyclicClientManager(SimpleClientManager): + """Provides a cyclic client selection rule.""" + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + + # Block until at least num_clients are connected. 
+ if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + # Return all available clients + return [self.clients[cid] for cid in available_cids] + +The customised :code:`ClientManager` samples all available clients in each FL round based on the order of connection to the server. +Then, we define a new strategy :code:`FedXgbCyclic` in :code:`flwr.server.strategy.fedxgb_cyclic.py`, +in order to sequentially select only one client in given round and pass the received model to next client. + +.. code-block:: python + + class FedXgbCyclic(FedAvg): + """Configurable FedXgbCyclic strategy implementation.""" + + # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long + def __init__( + self, + **kwargs: Any, + ): + self.global_model: Optional[bytes] = None + super().__init__(**kwargs) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using bagging.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Fetch the client model from last round as global model + for _, fit_res in results: + update = fit_res.parameters.tensors + for bst in update: + self.global_model = bst + + return ( + Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]), + {}, + ) + +Unlike the 
original :code:`FedAvg`, we don't perform aggregation here. +Instead, we just make a copy of the received client model as global model by overriding :code:`aggregate_fit`. + +Also, the customised :code:`configure_fit` and :code:`configure_evaluate` methods ensure the clients to be sequentially selected given FL round: + +.. code-block:: python + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + config = {} + if self.on_fit_config_fn is not None: + # Custom fit config function provided + config = self.on_fit_config_fn(server_round) + fit_ins = FitIns(parameters, config) + + # Sample clients + sample_size, min_num_clients = self.num_fit_clients( + client_manager.num_available() + ) + clients = client_manager.sample( + num_clients=sample_size, + min_num_clients=min_num_clients, + ) + + # Sample the clients sequentially given server_round + sampled_idx = (server_round - 1) % len(clients) + sampled_clients = [clients[sampled_idx]] + + # Return client/config pairs + return [(client, fit_ins) for client in sampled_clients] + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + # Do not configure federated evaluation if fraction eval is 0. 
+ if self.fraction_evaluate == 0.0: + return [] + + # Parameters and config + config = {} + if self.on_evaluate_config_fn is not None: + # Custom evaluation config function provided + config = self.on_evaluate_config_fn(server_round) + evaluate_ins = EvaluateIns(parameters, config) + + # Sample clients + sample_size, min_num_clients = self.num_evaluation_clients( + client_manager.num_available() + ) + clients = client_manager.sample( + num_clients=sample_size, + min_num_clients=min_num_clients, + ) + + # Sample the clients sequentially given server_round + sampled_idx = (server_round - 1) % len(clients) + sampled_clients = [clients[sampled_idx]] + + # Return client/config pairs + return [(client, evaluate_ins) for client in sampled_clients] + + + Customised data partitioning ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -634,7 +789,7 @@ Currently, we provide four supported partitioner type to simulate the uniformity Customised centralised/distributed evaluation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To facilitate centralised evaluation, we define a function in :code:`server.py`: +To facilitate centralised evaluation, we define a function in :code:`server_utils.py`: .. code-block:: python @@ -670,51 +825,265 @@ This function returns a evaluation function which instantiates a :code:`Booster` The evaluation is conducted by calling :code:`eval_set()` method, and the tested AUC value is reported. As for distributed evaluation on the clients, it's same as the quick-start example by -overriding the :code:`evaluate()` method insides the :code:`XgbClient` class in :code:`client.py`. +overriding the :code:`evaluate()` method inside the :code:`XgbClient` class in :code:`client_utils.py`. -Arguments parser -~~~~~~~~~~~~~~~~~~~~~~ +Flower simulation +~~~~~~~~~~~~~~~~~~~~ +We also provide an example code (:code:`sim.py`) to use the simulation capabilities of Flower to simulate federated XGBoost training on either a single machine or a cluster of machines. 
-In :code:`utils.py`, we define the arguments parsers for clients and server, allowing users to specify different experimental settings. -Let's first see the sever side: +.. code-block:: python + + from logging import INFO + import xgboost as xgb + from tqdm import tqdm + + import flwr as fl + from flwr_datasets import FederatedDataset + from flwr.common.logger import log + from flwr.server.strategy import FedXgbBagging, FedXgbCyclic + + from dataset import ( + instantiate_partitioner, + train_test_split, + transform_dataset_to_dmatrix, + separate_xy, + resplit, + ) + from utils import ( + sim_args_parser, + NUM_LOCAL_ROUND, + BST_PARAMS, + ) + from server_utils import ( + eval_config, + fit_config, + evaluate_metrics_aggregation, + get_evaluate_fn, + CyclicClientManager, + ) + from client_utils import XgbClient + +After importing all required packages, we define a :code:`main()` function to perform the simulation process: .. code-block:: python - import argparse + def main(): + # Parse arguments for experimental settings + args = sim_args_parser() + # Load (HIGGS) dataset and conduct partitioning + partitioner = instantiate_partitioner( + partitioner_type=args.partitioner_type, num_partitions=args.pool_size + ) + fds = FederatedDataset( + dataset="jxie/higgs", + partitioners={"train": partitioner}, + resplitter=resplit, + ) - def server_args_parser(): - """Parse arguments to define experimental settings on server side.""" - parser = argparse.ArgumentParser() + # Load centralised test set + if args.centralised_eval or args.centralised_eval_client: + log(INFO, "Loading centralised test set...") + test_data = fds.load_split("test") + test_data.set_format("numpy") + num_test = test_data.shape[0] + test_dmatrix = transform_dataset_to_dmatrix(test_data) + + # Load partitions and reformat data to DMatrix for xgboost + log(INFO, "Loading client local partitions...") + train_data_list = [] + valid_data_list = [] + + # Load and process all client partitions. 
This upfront cost is amortized soon + # after the simulation begins since clients won't need to preprocess their partition. + for node_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with node_id + partition = fds.load_partition(node_id=node_id, split="train") + partition.set_format("numpy") + + if args.centralised_eval_client: + # Use centralised test set for evaluation + train_data = partition + num_train = train_data.shape[0] + x_test, y_test = separate_xy(test_data) + valid_data_list.append(((x_test, y_test), num_test)) + else: + # Train/test splitting + train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=args.test_fraction, seed=args.seed + ) + x_valid, y_valid = separate_xy(valid_data) + valid_data_list.append(((x_valid, y_valid), num_val)) - parser.add_argument( - "--pool-size", default=2, type=int, help="Number of total clients." - ) - parser.add_argument( - "--num-rounds", default=5, type=int, help="Number of FL rounds." - ) - parser.add_argument( - "--num-clients-per-round", - default=2, - type=int, - help="Number of clients participate in training each round.", - ) - parser.add_argument( - "--num-evaluate-clients", - default=2, - type=int, - help="Number of clients selected for evaluation.", + x_train, y_train = separate_xy(train_data) + train_data_list.append(((x_train, y_train), num_train)) + +We first load the dataset and perform data partitioning, and the pre-processed data is stored in a :code:`list`. +After the simulation begins, the clients won't need to pre-process their partitions again. + +Then, we define the strategies and other hyper-parameters: + +.. 
code-block:: python + + # Define strategy + if args.train_method == "bagging": + # Bagging training + strategy = FedXgbBagging( + evaluate_function=get_evaluate_fn(test_dmatrix) + if args.centralised_eval + else None, + fraction_fit=(float(args.num_clients_per_round) / args.pool_size), + min_fit_clients=args.num_clients_per_round, + min_available_clients=args.pool_size, + min_evaluate_clients=args.num_evaluate_clients + if not args.centralised_eval + else 0, + fraction_evaluate=1.0 if not args.centralised_eval else 0.0, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation + if not args.centralised_eval + else None, ) - parser.add_argument( - "--centralised-eval", - action="store_true", - help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + else: + # Cyclic training + strategy = FedXgbCyclic( + fraction_fit=1.0, + min_available_clients=args.pool_size, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, ) - args = parser.parse_args() - return args + # Resources to be assigned to each virtual client + # In this example we use CPU by default + client_resources = { + "num_cpus": args.num_cpus_per_client, + "num_gpus": 0.0, + } + + # Hyper-parameters for xgboost training + num_local_round = NUM_LOCAL_ROUND + params = BST_PARAMS + + # Setup learning rate + if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.pool_size + params.update({"eta": new_lr}) + +After that, we start the simulation by calling :code:`fl.simulation.start_simulation`: + +.. 
code-block:: python -This allows user to specify the number of total clients / FL rounds / participating clients / clients for evaluation, + # Start simulation + fl.simulation.start_simulation( + client_fn=get_client_fn( + train_data_list, + valid_data_list, + args.train_method, + params, + num_local_round, + ), + num_clients=args.pool_size, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + client_manager=CyclicClientManager() if args.train_method == "cyclic" else None, + ) + +One of key parameters for :code:`start_simulation` is :code:`client_fn` which returns a function to construct a client. +We define it as follows: + +.. code-block:: python + + def get_client_fn( + train_data_list, valid_data_list, train_method, params, num_local_round + ): + """Return a function to construct a client. + + The VirtualClientEngine will execute this function whenever a client is sampled by + the strategy to participate. + """ + + def client_fn(cid: str) -> fl.client.Client: + """Construct a FlowerClient with its own dataset partition.""" + x_train, y_train = train_data_list[int(cid)][0] + x_valid, y_valid = valid_data_list[int(cid)][0] + + # Reformat data to DMatrix + train_dmatrix = xgb.DMatrix(x_train, label=y_train) + valid_dmatrix = xgb.DMatrix(x_valid, label=y_valid) + + # Fetch the number of examples + num_train = train_data_list[int(cid)][1] + num_val = valid_data_list[int(cid)][1] + + # Create and return client + return XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ) + + return client_fn + + + +Arguments parser +~~~~~~~~~~~~~~~~~~~~~~ + +In :code:`utils.py`, we define the arguments parsers for clients, server and simulation, allowing users to specify different experimental settings. +Let's first see the sever side: + +.. 
code-block:: python + + import argparse + + + def server_args_parser(): + """Parse arguments to define experimental settings on server side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + parser.add_argument( + "--pool-size", default=2, type=int, help="Number of total clients." + ) + parser.add_argument( + "--num-rounds", default=5, type=int, help="Number of FL rounds." + ) + parser.add_argument( + "--num-clients-per-round", + default=2, + type=int, + help="Number of clients participate in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=2, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + + args = parser.parse_args() + return args + +This allows users to specify training strategies / the number of total clients / FL rounds / participating clients / clients for evaluation, and evaluation fashion. Note that with :code:`--centralised-eval`, the server will do centralised evaluation and all functionalities for client evaluation will be disabled. @@ -723,60 +1092,159 @@ Then, the argument parser on client side: .. 
code-block:: python def client_args_parser(): - """Parse arguments to define experimental settings on client side.""" - parser = argparse.ArgumentParser() + """Parse arguments to define experimental settings on client side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + parser.add_argument( + "--num-partitions", default=10, type=int, help="Number of partitions." + ) + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--node-id", + default=0, + type=int, + help="Node ID used for the current client.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." + ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Perform scaled learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args - parser.add_argument( - "--num-partitions", default=10, type=int, help="Number of partitions." - ) - parser.add_argument( - "--partitioner-type", - default="uniform", - type=str, - choices=["uniform", "linear", "square", "exponential"], - help="Partitioner types.", - ) - parser.add_argument( - "--node-id", - default=0, - type=int, - help="Node ID used for the current client.", - ) - parser.add_argument( - "--seed", default=42, type=int, help="Seed used for train/test splitting." 
- ) - parser.add_argument( - "--test-fraction", - default=0.2, - type=float, - help="Test fraction for train/test splitting.", - ) - parser.add_argument( - "--centralised-eval", - action="store_true", - help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", - ) +This defines various options for client data partitioning. +Besides, clients also have an option to conduct evaluation on centralised test set by setting :code:`--centralised-eval`, +as well as an option to perform scaled learning rate based on the number of clients by setting :code:`--scaled-lr`. - args = parser.parse_args() - return args +We also have an argument parser for simulation: -This defines various options for client data partitioning. -Besides, clients also have a option to conduct evaluation on centralised test set by setting :code:`--centralised-eval`. +.. code-block:: python + + def sim_args_parser(): + """Parse arguments to define experimental settings on server side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + + # Server side + parser.add_argument( + "--pool-size", default=5, type=int, help="Number of total clients." + ) + parser.add_argument( + "--num-rounds", default=30, type=int, help="Number of FL rounds." 
+ ) + parser.add_argument( + "--num-clients-per-round", + default=5, + type=int, + help="Number of clients participate in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=5, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + parser.add_argument( + "--num-cpus-per-client", + default=2, + type=int, + help="Number of CPUs used for per client.", + ) + + # Client side + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." + ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval-client", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Perform scaled learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args + +This integrates all arguments for both client and server sides. Example commands ~~~~~~~~~~~~~~~~~~~~~ -To run a centralised evaluated experiment on 5 clients with exponential distribution for 50 rounds, +To run a centralised evaluated experiment with bagging strategy on 5 clients with exponential distribution for 50 rounds, we first start the server as below: .. 
code-block:: shell - $ python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval + $ python3 server.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval Then, on each client terminal, we start the clients: .. code-block:: shell - $ python3 clients.py --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID + $ python3 clients.py --train-method=bagging --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID + +To run the same experiment with Flower simulation: + +.. code-block:: shell + + $ python3 sim.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --partitioner-type=exponential --centralised-eval The full `code `_ for this comprehensive example can be found in :code:`examples/xgboost-comprehensive`. diff --git a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb index 5b2236468909..c5fc777e7f26 100644 --- a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb +++ b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Build a strategy from scratch\n", "\n", - "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", + "Welcome to the third part of the Flower federated learning tutorial. 
In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", "\n", - "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n", + "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's build a new `Strategy` from scratch!" 
] @@ -489,11 +489,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 4](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." + "The [Flower Federated Learning Tutorial - Part 4](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." ] } ], diff --git a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb index 0ff67de6f51d..dbdd1094173c 100644 --- a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb +++ b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Customize the client\n", "\n", - "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", + "Welcome to the fourth part of the Flower federated learning tutorial. 
In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", "\n", "In this notebook, we revisit `NumPyClient` and introduce a new baseclass for building clients, simply named `Client`. In previous parts of this tutorial, we've based our client on `NumPyClient`, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. With `Client`, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things the we didn't have to do before.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's go deeper and see what it takes to move from `NumPyClient` to `Client`!" 
 ] @@ -500,7 +500,7 @@ " # We convert our ndarray into a sparse matrix\n", " ndarray = torch.tensor(ndarray).to_sparse_csr()\n", "\n", - " # And send it by utilizng the sparse matrix attributes\n", + " # And send it by utilizing the sparse matrix attributes\n", " # WARNING: NEVER set allow_pickle to true.\n", " # Reason: loading pickled data can execute arbitrary code\n", " # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n", @@ -550,7 +550,7 @@ "source": [ "### Client-side\n", "\n", - "To be able to able to serialize our `ndarray`s into sparse parameters, we will just have to call our custom functions in our `flwr.client.Client`.\n", + "To be able to serialize our `ndarray`s into sparse parameters, we will just have to call our custom functions in our `flwr.client.Client`.\n", "\n", "Indeed, in `get_parameters` we need to serialize the parameters we got from our network using our custom `ndarrays_to_sparse_parameters` defined above.\n", "\n", @@ -813,7 +813,7 @@ " for _, fit_res in results\n", " ]\n", "\n", - " # We serialize the aggregated result using our cutom method\n", + " # We serialize the aggregated result using our custom method\n", " parameters_aggregated = ndarrays_to_sparse_parameters(\n", " aggregate(weights_results)\n", " )\n", @@ -869,16 +869,16 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", "This is the final part of the Flower tutorial (for now!), congratulations! You're now well equipped to understand the rest of the documentation. 
There are many topics we didn't cover in the tutorial, we recommend the following resources:\n", "\n", - "- [Read Flower Docs](https://flower.dev/docs/)\n", + "- [Read Flower Docs](https://flower.ai/docs/)\n", "- [Check out Flower Code Examples](https://github.com/adap/flower/tree/main/examples)\n", - "- [Use Flower Baselines for your research](https://flower.dev/docs/baselines/)\n", - "- [Watch Flower Summit 2023 videos](https://flower.dev/conf/flower-summit-2023/)\n" + "- [Use Flower Baselines for your research](https://flower.ai/docs/baselines/)\n", + "- [Watch Flower Summit 2023 videos](https://flower.ai/conf/flower-summit-2023/)\n" ] } ], diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb old mode 100755 new mode 100644 index 41c9254e9d69..2b8dd382bb79 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -9,9 +9,9 @@ "\n", "Welcome to the Flower federated learning tutorial!\n", "\n", - "In this notebook, we'll build a federated learning system using Flower and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", + "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.ai/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! 
And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get stated!" ] @@ -31,7 +31,7 @@ "source": [ "### Installing dependencies\n", "\n", - "Next, we install the necessary packages for PyTorch (`torch` and `torchvision`) and Flower (`flwr`):" + "Next, we install the necessary packages for PyTorch (`torch` and `torchvision`), Flower Datasets (`flwr-datasets`) and Flower (`flwr`):" ] }, { @@ -40,7 +40,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q flwr[simulation] torch torchvision matplotlib" + "!pip install -q flwr[simulation] flwr_datasets[vision] torch torchvision matplotlib" ] }, { @@ -64,25 +64,26 @@ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "import torchvision\n", "import torchvision.transforms as transforms\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision.datasets import CIFAR10\n", + "from datasets.utils.logging import disable_progress_bar\n", + "from torch.utils.data import DataLoader\n", "\n", "import flwr as fl\n", "from flwr.common import Metrics\n", + "from flwr_datasets import FederatedDataset\n", "\n", "DEVICE = torch.device(\"cpu\") # Try \"cuda\" to train on GPU\n", "print(\n", " f\"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}\"\n", - ")" + ")\n", + "disable_progress_bar()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware acclerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. 
If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`, otherwise it'll say `Training on cpu`." + "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`, otherwise it'll say `Training on cpu`." ] }, { @@ -92,27 +93,7 @@ "\n", "### Loading the data\n", "\n", - "Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CLASSES = (\n", - " \"plane\",\n", - " \"car\",\n", - " \"bird\",\n", - " \"cat\",\n", - " \"deer\",\n", - " \"dog\",\n", - " \"frog\",\n", - " \"horse\",\n", - " \"ship\",\n", - " \"truck\",\n", - ")" + "Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. 
CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'." ] }, { @@ -121,16 +102,7 @@ "source": [ "We simulate having multiple datasets from multiple organizations (also called the \"cross-silo\" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes, in the real world there's no need for data splitting because each organization already has their own data (so the data is naturally partitioned).\n", "\n", - "Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_CLIENTS = 10" + "Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server.\n" ] }, { @@ -138,7 +110,7 @@ "metadata": {}, "source": [ "\n", - "Let's now load the CIFAR-10 training and test set, partition them into ten smaller datasets (each split into training and validation set), and wrap the resulting partitions by creating a PyTorch `DataLoader` for each of them:" + "Let's now create the Federated Dataset abstraction from `flwr-datasets` that partitions the CIFAR-10. 
We will create small training and test set for each edge device and wrap each of them into a PyTorch `DataLoader`:" ] }, { @@ -147,32 +119,36 @@ "metadata": {}, "outputs": [], "source": [ + "NUM_CLIENTS = 10\n", "BATCH_SIZE = 32\n", "\n", "\n", "def load_datasets():\n", - " # Download and transform CIFAR-10 (train and test)\n", - " transform = transforms.Compose(\n", - " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", - " )\n", - " trainset = CIFAR10(\"./dataset\", train=True, download=True, transform=transform)\n", - " testset = CIFAR10(\"./dataset\", train=False, download=True, transform=transform)\n", - "\n", - " # Split training set into 10 partitions to simulate the individual dataset\n", - " partition_size = len(trainset) // NUM_CLIENTS\n", - " lengths = [partition_size] * NUM_CLIENTS\n", - " datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))\n", - "\n", - " # Split each partition into train/val and create DataLoader\n", + " fds = FederatedDataset(dataset=\"cifar10\", partitioners={\"train\": NUM_CLIENTS})\n", + "\n", + " def apply_transforms(batch):\n", + " # Instead of passing transforms to CIFAR10(..., transform=transform)\n", + " # we will use this function to dataset.with_transform(apply_transforms)\n", + " # The transforms object is exactly the same\n", + " transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", + " ]\n", + " )\n", + " batch[\"img\"] = [transform(img) for img in batch[\"img\"]]\n", + " return batch\n", + "\n", + " # Create train/val for each partition and wrap it into DataLoader\n", " trainloaders = []\n", " valloaders = []\n", - " for ds in datasets:\n", - " len_val = len(ds) // 10 # 10 % validation set\n", - " len_train = len(ds) - len_val\n", - " lengths = [len_train, len_val]\n", - " ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))\n", - " 
trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))\n", - " valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))\n", + " for partition_id in range(NUM_CLIENTS):\n", + " partition = fds.load_partition(partition_id, \"train\")\n", + " partition = partition.with_transform(apply_transforms)\n", + " partition = partition.train_test_split(train_size=0.8)\n", + " trainloaders.append(DataLoader(partition[\"train\"], batch_size=BATCH_SIZE))\n", + " valloaders.append(DataLoader(partition[\"test\"], batch_size=BATCH_SIZE))\n", + " testset = fds.load_split(\"test\").with_transform(apply_transforms)\n", " testloader = DataLoader(testset, batch_size=BATCH_SIZE)\n", " return trainloaders, valloaders, testloader\n", "\n", @@ -195,8 +171,8 @@ "metadata": {}, "outputs": [], "source": [ - "images, labels = next(iter(trainloaders[0]))\n", - "\n", + "batch = next(iter(trainloaders[0]))\n", + "images, labels = batch[\"img\"], batch[\"label\"]\n", "# Reshape and convert images to a NumPy array\n", "# matplotlib requires images with the shape (height, width, 3)\n", "images = images.permute(0, 2, 3, 1).numpy()\n", @@ -209,7 +185,7 @@ "# Loop over the images and plot them\n", "for i, ax in enumerate(axs.flat):\n", " ax.imshow(images[i])\n", - " ax.set_title(CLASSES[labels[i]])\n", + " ax.set_title(trainloaders[0].dataset.features[\"label\"].int2str([labels[i]])[0])\n", " ax.axis(\"off\")\n", "\n", "# Show the plot\n", @@ -294,8 +270,8 @@ " net.train()\n", " for epoch in range(epochs):\n", " correct, total, epoch_loss = 0, 0, 0.0\n", - " for images, labels in trainloader:\n", - " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + " for batch in trainloader:\n", + " images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n", " optimizer.zero_grad()\n", " outputs = net(images)\n", " loss = criterion(outputs, labels)\n", @@ -317,8 +293,8 @@ " correct, total, loss = 0, 0, 0.0\n", " net.eval()\n", " with torch.no_grad():\n", - " for 
images, labels in testloader:\n", - " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + " for batch in testloader:\n", + " images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n", " outputs = net(images)\n", " loss += criterion(outputs, labels).item()\n", " _, predicted = torch.max(outputs.data, 1)\n", @@ -392,14 +368,14 @@ "metadata": {}, "outputs": [], "source": [ - "def get_parameters(net) -> List[np.ndarray]:\n", - " return [val.cpu().numpy() for _, val in net.state_dict().items()]\n", - "\n", - "\n", "def set_parameters(net, parameters: List[np.ndarray]):\n", " params_dict = zip(net.state_dict().keys(), parameters)\n", " state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n", - " net.load_state_dict(state_dict, strict=True)" + " net.load_state_dict(state_dict, strict=True)\n", + "\n", + "\n", + "def get_parameters(net) -> List[np.ndarray]:\n", + " return [val.cpu().numpy() for _, val in net.state_dict().items()]" ] }, { @@ -477,7 +453,7 @@ " valloader = valloaders[int(cid)]\n", "\n", " # Create a single Flower client representing a single organization\n", - " return FlowerClient(net, trainloader, valloader)" + " return FlowerClient(net, trainloader, valloader).to_client()" ] }, { @@ -508,10 +484,14 @@ " min_available_clients=10, # Wait until all 10 clients are available\n", ")\n", "\n", - "# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)\n", - "client_resources = None\n", + "# Specify the resources each of your clients need. 
By default, each\n", + "# client will be allocated 1x CPU and 0x GPUs\n", + "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "if DEVICE.type == \"cuda\":\n", - " client_resources = {\"num_gpus\": 1}\n", + " # here we are assigning an entire GPU for each client.\n", + " client_resources = {\"num_cpus\": 1, \"num_gpus\": 1.0}\n", + " # Refer to our documentation for more details about Flower Simulations\n", + " # and how to setup these `client_resources`.\n", "\n", "# Start simulation\n", "fl.simulation.start_simulation(\n", @@ -625,11 +605,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them." 
+ "The [Flower Federated Learning Tutorial - Part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" ] } ], @@ -640,7 +620,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "flower-3.7.12", + "display_name": "flwr", "language": "python", "name": "python3" } diff --git a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb index 06f53cd8e1b1..e20a8d83f674 100644 --- a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb +++ b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Use a federated learning strategy\n", "\n", - "Welcome to the next part of the federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n", + "Welcome to the next part of the federated learning tutorial. 
In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n", "\n", - "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n", + "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's move beyond FedAvg with Flower strategies!" ] @@ -614,11 +614,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." 
+ "The [Flower Federated Learning Tutorial - Part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." ] } ], diff --git a/doc/source/tutorial-series-what-is-federated-learning.ipynb b/doc/source/tutorial-series-what-is-federated-learning.ipynb index 3f7e383b9fbc..7d77d1770457 100755 --- a/doc/source/tutorial-series-what-is-federated-learning.ipynb +++ b/doc/source/tutorial-series-what-is-federated-learning.ipynb @@ -13,7 +13,7 @@ "\n", "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get started!" ] @@ -98,7 +98,7 @@ "- Location data from your electric car to make better range prediction\n", "- End-to-end encrypted messages to train better auto-complete models\n", "\n", - "The popularity of privacy-enhancing systems like the [Brave](https://brave.com/) browser or the [Signal](https://signal.org/) messenger shows that users care about privacy. In fact, they choose the privacy-enhancing version over other alternatives, if such an alernative exists. But what can we do to apply machine learning and data science to these cases to utilize private data? 
After all, these are all areas that would benefit significantly from recent advances in AI." + "The popularity of privacy-enhancing systems like the [Brave](https://brave.com/) browser or the [Signal](https://signal.org/) messenger shows that users care about privacy. In fact, they choose the privacy-enhancing version over other alternatives, if such an alternative exists. But what can we do to apply machine learning and data science to these cases to utilize private data? After all, these are all areas that would benefit significantly from recent advances in AI." ] }, { @@ -217,11 +217,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." + "The [Flower Federated Learning Tutorial - Part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." 
] } ], diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 20a5b4875ddf..b4570b36512d 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -25,15 +25,15 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), insecure=False, ) diff --git a/e2e/bare-https/driver.py b/e2e/bare-https/driver.py index 5c44e4c641ae..f7bfeb613f6a 100644 --- a/e2e/bare-https/driver.py +++ b/e2e/bare-https/driver.py @@ -3,7 +3,7 @@ # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="127.0.0.1:9091", config=fl.server.ServerConfig(num_rounds=3), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare-https/pyproject.toml b/e2e/bare-https/pyproject.toml index 9489a43195f9..3afb7b57a084 100644 --- a/e2e/bare-https/pyproject.toml +++ b/e2e/bare-https/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "bare_https_test" version = "0.1.0" description = "HTTPS-enabled bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 8e5c3adff5e6..c291fb0963e4 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -3,6 +3,8 @@ import flwr as fl import numpy as np +from flwr.common import ConfigsRecord + SUBSET_SIZE = 1000 STATE_VAR = 'timestamp' @@ -18,13 +20,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in 
self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): model_params = parameters @@ -42,10 +46,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/bare/driver.py b/e2e/bare/driver.py index 6bd61e344ad1..defc2ad56213 100644 --- a/e2e/bare/driver.py +++ b/e2e/bare/driver.py @@ -2,7 +2,7 @@ # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/bare/pyproject.toml b/e2e/bare/pyproject.toml index b9a4028806c3..45ce7ea333af 100644 --- a/e2e/bare/pyproject.toml +++ b/e2e/bare/pyproject.toml @@ -6,8 +6,8 @@ build-backend = "poetry.core.masonry.api" name = "bare_test" version = "0.1.0" description = "Bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" -flwr = { path = "../../", develop = true, extras = ["simulation"] } +flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } diff --git a/e2e/fastai/client.py b/e2e/fastai/client.py index 
4425fed25277..c4bfb89c2dde 100644 --- a/e2e/fastai/client.py +++ b/e2e/fastai/client.py @@ -53,14 +53,14 @@ def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/fastai/driver.py b/e2e/fastai/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/fastai/driver.py +++ b/e2e/fastai/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/fastai/pyproject.toml b/e2e/fastai/pyproject.toml index 66fcb393d988..feed31f6d202 100644 --- a/e2e/fastai/pyproject.toml +++ b/e2e/fastai/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-fastai" version = "0.1.0" description = "Fastai Federated Learning E2E test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.10" diff --git a/e2e/jax/client.py b/e2e/jax/client.py index 495d6a671981..a4e4d1f55117 100644 --- a/e2e/jax/client.py +++ b/e2e/jax/client.py @@ -53,10 +53,10 @@ def evaluate( def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/jax/driver.py b/e2e/jax/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/jax/driver.py +++ b/e2e/jax/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = 
fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/jax/pyproject.toml b/e2e/jax/pyproject.toml index fde3c32608ca..9a4af5dee59a 100644 --- a/e2e/jax/pyproject.toml +++ b/e2e/jax/pyproject.toml @@ -2,13 +2,13 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" flwr = { path = "../../", develop = true, extras = ["simulation"] } -jax = "^0.4.0" -jaxlib = "^0.4.0" +jax = "==0.4.13" +jaxlib = "==0.4.13" scikit-learn = "^1.1.1" numpy = "^1.21.4" diff --git a/e2e/mxnet/.gitignore b/e2e/mxnet/.gitignore deleted file mode 100644 index 10d00b5797e2..000000000000 --- a/e2e/mxnet/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.gz diff --git a/e2e/mxnet/README.md b/e2e/mxnet/README.md deleted file mode 100644 index 3fa76bac5ce0..000000000000 --- a/e2e/mxnet/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Flower with MXNet testing - -This directory is used for testing Flower with MXNet by using a simple NN with MNIST data. - -It uses the `FedAvg` strategy. \ No newline at end of file diff --git a/e2e/mxnet/client.py b/e2e/mxnet/client.py deleted file mode 100644 index 2f0b714e708c..000000000000 --- a/e2e/mxnet/client.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Flower client example using MXNet for MNIST classification. 
- -The code is generally adapted from: - -https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html -""" - - -import flwr as fl -import numpy as np -import mxnet as mx -from mxnet import nd -from mxnet import gluon -from mxnet.gluon import nn -from mxnet import autograd as ag -import mxnet.ndarray as F - -SUBSET_SIZE = 50 - -# Fixing the random seed -mx.random.seed(42) - -# Setup context to GPU or CPU -DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()] - -def model(): - net = nn.Sequential() - net.add(nn.Dense(256, activation="relu")) - net.add(nn.Dense(64, activation="relu")) - net.add(nn.Dense(10)) - net.collect_params().initialize() - return net - -def load_data(): - print("Download Dataset") - mnist = mx.test_utils.get_mnist() - batch_size = 100 - train_data = mx.io.NDArrayIter( - mnist["train_data"][:SUBSET_SIZE], mnist["train_label"][:SUBSET_SIZE], batch_size, shuffle=True - ) - val_data = mx.io.NDArrayIter(mnist["test_data"][:10], mnist["test_label"][:10], batch_size) - return train_data, val_data - - -def train(net, train_data, epoch): - trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01}) - accuracy_metric = mx.metric.Accuracy() - loss_metric = mx.metric.CrossEntropy() - metrics = mx.metric.CompositeEvalMetric() - for child_metric in [accuracy_metric, loss_metric]: - metrics.add(child_metric) - softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss() - for i in range(epoch): - train_data.reset() - num_examples = 0 - for batch in train_data: - data = gluon.utils.split_and_load( - batch.data[0], ctx_list=DEVICE, batch_axis=0 - ) - label = gluon.utils.split_and_load( - batch.label[0], ctx_list=DEVICE, batch_axis=0 - ) - outputs = [] - with ag.record(): - for x, y in zip(data, label): - z = net(x) - loss = softmax_cross_entropy_loss(z, y) - loss.backward() - outputs.append(z.softmax()) - num_examples += len(x) - metrics.update(label, outputs) - trainer.step(batch.data[0].shape[0]) - 
trainings_metric = metrics.get_name_value() - print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric)) - return trainings_metric, num_examples - - -def test(net, val_data): - accuracy_metric = mx.metric.Accuracy() - loss_metric = mx.metric.CrossEntropy() - metrics = mx.metric.CompositeEvalMetric() - for child_metric in [accuracy_metric, loss_metric]: - metrics.add(child_metric) - val_data.reset() - num_examples = 0 - for batch in val_data: - data = gluon.utils.split_and_load(batch.data[0], ctx_list=DEVICE, batch_axis=0) - label = gluon.utils.split_and_load( - batch.label[0], ctx_list=DEVICE, batch_axis=0 - ) - outputs = [] - for x in data: - outputs.append(net(x).softmax()) - num_examples += len(x) - metrics.update(label, outputs) - metrics.update(label, outputs) - return metrics.get_name_value(), num_examples - -train_data, val_data = load_data() - -model = model() -init = nd.random.uniform(shape=(2, 784)) -model(init) - -# Flower Client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - param = [] - for val in model.collect_params(".*weight").values(): - p = val.data() - param.append(p.asnumpy()) - return param - - def set_parameters(self, parameters): - params = zip(model.collect_params(".*weight").keys(), parameters) - for key, value in params: - model.collect_params().setattr(key, value) - - def fit(self, parameters, config): - self.set_parameters(parameters) - [accuracy, loss], num_examples = train(model, train_data, epoch=2) - results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])} - return self.get_parameters(config={}), num_examples, results - - def evaluate(self, parameters, config): - self.set_parameters(parameters) - [accuracy, loss], num_examples = test(model, val_data) - print("Evaluation accuracy & loss", accuracy, loss) - return float(loss[1]), num_examples, {"accuracy": float(accuracy[1])} - - -def client_fn(cid): - return FlowerClient().to_client() - -flower = fl.flower.Flower( - 
client_fn=client_fn, -) - -if __name__ == "__main__": - # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) diff --git a/e2e/mxnet/driver.py b/e2e/mxnet/driver.py deleted file mode 100644 index 2b1b35d9e89c..000000000000 --- a/e2e/mxnet/driver.py +++ /dev/null @@ -1,7 +0,0 @@ -import flwr as fl - -hist = fl.driver.start_driver( - server_address="0.0.0.0:9091", - config=fl.server.ServerConfig(num_rounds=3), -) -assert (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 1 diff --git a/e2e/mxnet/pyproject.toml b/e2e/mxnet/pyproject.toml deleted file mode 100644 index 71bd0e6374bd..000000000000 --- a/e2e/mxnet/pyproject.toml +++ /dev/null @@ -1,15 +0,0 @@ -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "mxnet_example" -version = "0.1.0" -description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] - -[tool.poetry.dependencies] -python = "^3.8" -flwr = { path = "../../", develop = true, extras = ["simulation"] } -mxnet = "^1.7.0" -numpy = "1.23.1" diff --git a/e2e/mxnet/simulation.py b/e2e/mxnet/simulation.py deleted file mode 100644 index 5f0e5334bd08..000000000000 --- a/e2e/mxnet/simulation.py +++ /dev/null @@ -1,11 +0,0 @@ -import flwr as fl - -from client import client_fn - -hist = fl.simulation.start_simulation( - client_fn=client_fn, - num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), -) - -assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 diff --git a/e2e/opacus/client.py b/e2e/opacus/client.py index 2e5c363381fa..00437a31233c 100644 --- a/e2e/opacus/client.py +++ b/e2e/opacus/client.py @@ -137,12 +137,12 @@ def client_fn(cid): model = Net() return FlowerClient(model).to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_numpy_client( + 
fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(model) + client=FlowerClient(model).to_client() ) diff --git a/e2e/opacus/driver.py b/e2e/opacus/driver.py index 5a0309914ee0..75acd9ccea24 100644 --- a/e2e/opacus/driver.py +++ b/e2e/opacus/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/opacus/pyproject.toml b/e2e/opacus/pyproject.toml index 1aee7c6ec6d9..ab4a727cc00b 100644 --- a/e2e/opacus/pyproject.toml +++ b/e2e/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "opacus_e2e" version = "0.1.0" description = "Opacus E2E testing" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pandas/client.py b/e2e/pandas/client.py index 5b8670091cb3..0ecd75df3ae8 100644 --- a/e2e/pandas/client.py +++ b/e2e/pandas/client.py @@ -1,4 +1,3 @@ -import warnings from typing import Dict, List, Tuple import numpy as np @@ -36,13 +35,13 @@ def fit( def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/pandas/driver.py b/e2e/pandas/driver.py index b33e1e54f4a0..f5dc74c9f3f8 100644 --- a/e2e/pandas/driver.py +++ b/e2e/pandas/driver.py @@ -3,7 +3,7 @@ from strategy import FedAnalytics # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=1), strategy=FedAnalytics(), diff --git a/e2e/pandas/pyproject.toml b/e2e/pandas/pyproject.toml index b90037a31068..416dfeec3460 100644 --- 
a/e2e/pandas/pyproject.toml +++ b/e2e/pandas/pyproject.toml @@ -7,7 +7,7 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pandas/simulation.py b/e2e/pandas/simulation.py index 91af84062712..b548b5ebb760 100644 --- a/e2e/pandas/simulation.py +++ b/e2e/pandas/simulation.py @@ -1,12 +1,8 @@ import flwr as fl -from client import FlowerClient +from client import client_fn from strategy import FedAnalytics -def client_fn(cid): - _ = cid - return FlowerClient() - hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, diff --git a/e2e/pytorch-lightning/client.py b/e2e/pytorch-lightning/client.py index 71b178eca8c3..fde550e31c08 100644 --- a/e2e/pytorch-lightning/client.py +++ b/e2e/pytorch-lightning/client.py @@ -55,7 +55,7 @@ def client_fn(cid): # Flower client return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) @@ -65,8 +65,8 @@ def main() -> None: train_loader, val_loader, test_loader = mnist.load_data() # Flower client - client = FlowerClient(model, train_loader, val_loader, test_loader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/pytorch-lightning/driver.py b/e2e/pytorch-lightning/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/pytorch-lightning/driver.py +++ b/e2e/pytorch-lightning/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git 
a/e2e/pytorch-lightning/pyproject.toml b/e2e/pytorch-lightning/pyproject.toml index e79eb72a56df..88cddddf500f 100644 --- a/e2e/pytorch-lightning/pyproject.toml +++ b/e2e/pytorch-lightning/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" description = "Federated Learning E2E test with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" flwr = { path = "../../", develop = true, extras = ["simulation"] } -pytorch-lightning = "1.6.0" +pytorch-lightning = "2.1.3" torchvision = "0.14.1" diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index d180ad5d4eca..1fd07763148e 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -11,6 +11,7 @@ from tqdm import tqdm import flwr as fl +from flwr.common import ConfigsRecord # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -95,14 +96,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] - + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) @@ -124,14 +126,14 @@ def set_parameters(model, parameters): def client_fn(cid): return 
FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/e2e/pytorch/driver.py b/e2e/pytorch/driver.py index ca860ea47b2d..2ea4de69a62b 100644 --- a/e2e/pytorch/driver.py +++ b/e2e/pytorch/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/pytorch/pyproject.toml b/e2e/pytorch/pyproject.toml index 4aa326330c62..e538f1437df6 100644 --- a/e2e/pytorch/pyproject.toml +++ b/e2e/pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/scikit-learn/client.py b/e2e/scikit-learn/client.py index fdca96c1697a..e073d3cb2748 100644 --- a/e2e/scikit-learn/client.py +++ b/e2e/scikit-learn/client.py @@ -46,10 +46,10 @@ def evaluate(self, parameters, config): # type: ignore def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=FlowerClient()) + fl.client.start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/scikit-learn/driver.py b/e2e/scikit-learn/driver.py index 032a2f7a0dc6..29051d02c6b6 100644 --- 
a/e2e/scikit-learn/driver.py +++ b/e2e/scikit-learn/driver.py @@ -36,7 +36,7 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config): evaluate_fn=get_evaluate_fn(model), on_fit_config_fn=fit_round, ) - hist = fl.driver.start_driver( + hist = fl.server.start_driver( server_address="0.0.0.0:9091", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3), diff --git a/e2e/scikit-learn/pyproject.toml b/e2e/scikit-learn/pyproject.toml index 372ae1218bbe..50c07d31add7 100644 --- a/e2e/scikit-learn/pyproject.toml +++ b/e2e/scikit-learn/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/e2e/scikit-learn/utils.py b/e2e/scikit-learn/utils.py index d6804fdcce12..2b8dcf8655ee 100644 --- a/e2e/scikit-learn/utils.py +++ b/e2e/scikit-learn/utils.py @@ -10,7 +10,7 @@ def get_model_parameters(model: LogisticRegression) -> LogRegParams: - """Returns the paramters of a sklearn LogisticRegression model.""" + """Returns the parameters of a sklearn LogisticRegression model.""" if model.fit_intercept: params = [ model.coef_, diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index eb4598cb5439..3b49f770dc6b 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -47,11 +47,11 @@ def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/pyproject.toml b/e2e/strategies/pyproject.toml index 611325ce721b..5cc74b20fa24 100644 --- a/e2e/strategies/pyproject.toml +++ 
b/e2e/strategies/pyproject.toml @@ -6,9 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "quickstart_tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorflow-cpu = "^2.9.1, !=2.11.1" +tensorflow-io-gcs-filesystem = "<0.35.0" diff --git a/e2e/tabnet/client.py b/e2e/tabnet/client.py index 3c10df0c79f1..0290ba4629de 100644 --- a/e2e/tabnet/client.py +++ b/e2e/tabnet/client.py @@ -81,10 +81,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/tabnet/driver.py b/e2e/tabnet/driver.py index 2b1b35d9e89c..cc452ea523ca 100644 --- a/e2e/tabnet/driver.py +++ b/e2e/tabnet/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml index 644f422d6762..b1abf382a24a 100644 --- a/e2e/tabnet/pyproject.toml +++ b/e2e/tabnet/pyproject.toml @@ -6,12 +6,13 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == 
\"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.9.2" +tensorflow-io-gcs-filesystem = "<0.35.0" tabnet = "0.1.6" diff --git a/e2e/tensorflow/client.py b/e2e/tensorflow/client.py index 4ad2d5ebda57..10ee91136241 100644 --- a/e2e/tensorflow/client.py +++ b/e2e/tensorflow/client.py @@ -34,10 +34,10 @@ def evaluate(self, parameters, config): def client_fn(cid): return FlowerClient().to_client() -flower = fl.flower.Flower( +app = fl.client.ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/tensorflow/driver.py b/e2e/tensorflow/driver.py index ca860ea47b2d..2ea4de69a62b 100644 --- a/e2e/tensorflow/driver.py +++ b/e2e/tensorflow/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/tensorflow/pyproject.toml b/e2e/tensorflow/pyproject.toml index c66ffc30fdf0..a7dbfe2305db 100644 --- a/e2e/tensorflow/pyproject.toml +++ b/e2e/tensorflow/pyproject.toml @@ -6,9 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = 
["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorflow-cpu = "^2.9.1, !=2.11.1" +tensorflow-io-gcs-filesystem = "<0.35.0" diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index ca54dbf4852f..3d4864a1b0fb 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -13,13 +13,34 @@ case "$1" in ;; esac -timeout 2m flower-server $server_arg & +case "$2" in + rest) + rest_arg="--rest" + server_address="http://localhost:9093" + db_arg="--database :flwr-in-memory-state:" + ;; + sqlite) + rest_arg="" + server_address="127.0.0.1:9092" + db_arg="--database $(date +%s).db" + ;; + *) + rest_arg="" + server_address="127.0.0.1:9092" + db_arg="--database :flwr-in-memory-state:" + ;; +esac + +timeout 2m flower-superlink $server_arg $db_arg $rest_arg & +sl_pid=$! sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client-app client:app $client_arg $rest_arg --server $server_address & +cl1_pid=$! sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client-app client:app $client_arg $rest_arg --server $server_address & +cl2_pid=$! sleep 3 timeout 2m python driver.py & @@ -29,7 +50,7 @@ wait $pid res=$? if [[ "$res" = "0" ]]; - then echo "Training worked correctly" && pkill flower-client && pkill flower-server; + then echo "Training worked correctly"; kill $cl1_pid; kill $cl2_pid; kill $sl_pid; else echo "Training had an issue" && exit 1; fi diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index db0245e41453..c1ba85b95879 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (PyTorch) -This example demonstrates an advanced federated learning setup using Flower with PyTorch. 
It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set). @@ -59,12 +59,13 @@ pip install -r requirements.txt The included `run.sh` will start the Flower server (using `server.py`), sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine), -but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows: +but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# After activating your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). 
-You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network). +You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficienNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`. diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index f9ffb6181fd8..d4c8abe3d404 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -1,11 +1,11 @@ import utils from torch.utils.data import DataLoader -import torchvision.datasets import torch import flwr as fl import argparse from collections import OrderedDict import warnings +import datasets warnings.filterwarnings("ignore") @@ -13,47 +13,49 @@ class CifarClient(fl.client.NumPyClient): def __init__( self, - trainset: torchvision.datasets, - testset: torchvision.datasets, - device: str, + trainset: datasets.Dataset, + testset: datasets.Dataset, + device: torch.device, + model_str: str, validation_split: int = 0.1, ): self.device = device self.trainset = trainset self.testset = testset self.validation_split = validation_split + if model_str == "alexnet": + self.model = utils.load_alexnet(classes=10) + else: + self.model = utils.load_efficientnet(classes=10) def set_parameters(self, parameters): - """Loads a efficientnet model and replaces it parameters with the ones given.""" - model = utils.load_efficientnet(classes=10) - params_dict = zip(model.state_dict().keys(), parameters) + """Loads a alexnet or efficientnet model and 
replaces it parameters with the + ones given.""" + + params_dict = zip(self.model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - model.load_state_dict(state_dict, strict=True) - return model + self.model.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): """Train parameters on the locally held training set.""" # Update local model parameters - model = self.set_parameters(parameters) + self.set_parameters(parameters) # Get hyperparameters for this round batch_size: int = config["batch_size"] epochs: int = config["local_epochs"] - n_valset = int(len(self.trainset) * self.validation_split) - - valset = torch.utils.data.Subset(self.trainset, range(0, n_valset)) - trainset = torch.utils.data.Subset( - self.trainset, range(n_valset, len(self.trainset)) - ) + train_valid = self.trainset.train_test_split(self.validation_split) + trainset = train_valid["train"] + valset = train_valid["test"] - trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True) - valLoader = DataLoader(valset, batch_size=batch_size) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(valset, batch_size=batch_size) - results = utils.train(model, trainLoader, valLoader, epochs, self.device) + results = utils.train(self.model, train_loader, val_loader, epochs, self.device) - parameters_prime = utils.get_model_params(model) + parameters_prime = utils.get_model_params(self.model) num_examples_train = len(trainset) return parameters_prime, num_examples_train, results @@ -61,7 +63,7 @@ def fit(self, parameters, config): def evaluate(self, parameters, config): """Evaluate parameters on the locally held test set.""" # Update local model parameters - model = self.set_parameters(parameters) + self.set_parameters(parameters) # Get config values steps: int = config["val_steps"] @@ -69,17 +71,17 @@ def evaluate(self, parameters, config): # Evaluate global model 
parameters on the local test data and return results testloader = DataLoader(self.testset, batch_size=16) - loss, accuracy = utils.test(model, testloader, steps, self.device) + loss, accuracy = utils.test(self.model, testloader, steps, self.device) return float(loss), len(self.testset), {"accuracy": float(accuracy)} -def client_dry_run(device: str = "cpu"): +def client_dry_run(device: torch.device = "cpu"): """Weak tests to check whether all client methods are working as expected.""" model = utils.load_efficientnet(classes=10) trainset, testset = utils.load_partition(0) - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) client = CifarClient(trainset, testset, device) client.fit( utils.get_model_params(model), @@ -102,7 +104,7 @@ def main() -> None: help="Do a dry-run to check the client", ) parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -112,9 +114,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action="store_true", help="Set to true to quicky run the client using only 10 datasamples. \ Useful for testing purposes. Default: False", ) @@ -125,6 +125,14 @@ def main() -> None: required=False, help="Set to true to use GPU. Default: False", ) + parser.add_argument( + "--model", + type=str, + default="efficientnet", + choices=["efficientnet", "alexnet"], + help="Use either Efficientnet or Alexnet models. 
\ + If you want to achieve differential privacy, please use the Alexnet model", + ) args = parser.parse_args() @@ -136,16 +144,14 @@ def main() -> None: client_dry_run(device) else: # Load a subset of CIFAR-10 to simulate the local data partition - trainset, testset = utils.load_partition(args.partition) + trainset, testset = utils.load_partition(args.client_id) if args.toy: - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) - + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) # Start Flower client - client = CifarClient(trainset, testset, device) - - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(trainset, testset, device, args.model).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index a12f3c47de70..b846a6054cc8 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -7,13 +7,14 @@ name = "advanced-pytorch" version = "0.1.0" description = "Advanced Flower/PyTorch Example" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" torchvision = "0.14.1" validators = "0.18.2" diff --git a/examples/advanced-pytorch/requirements.txt b/examples/advanced-pytorch/requirements.txt index ba7b284df90e..f4d6a0774162 100644 --- a/examples/advanced-pytorch/requirements.txt +++ b/examples/advanced-pytorch/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 validators==0.18.2 diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh index 
212285f504f9..c3d52491b987 100755 --- a/examples/advanced-pytorch/run.sh +++ b/examples/advanced-pytorch/run.sh @@ -2,20 +2,12 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)" - -# Download the EfficientNetB0 model -python -c "import torch; torch.hub.load( \ - 'NVIDIA/DeepLearningExamples:torchhub', \ - 'nvidia_efficientnet_b0', pretrained=True)" - -python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start +python server.py --toy & +sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset for i in `seq 0 9`; do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # Enable CTRL+C to stop all background processes diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 8343e62da69f..489694ab1ea1 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -10,6 +10,8 @@ import warnings +from flwr_datasets import FederatedDataset + warnings.filterwarnings("ignore") @@ -39,18 +41,13 @@ def evaluate_config(server_round: int): def get_evaluate_fn(model: torch.nn.Module, toy: bool): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - trainset, _, _ = utils.load_data() - - n_train = len(trainset) + # Load data here to avoid the overhead of doing it in `evaluate` itself + centralized_data = utils.load_centralized_data() if toy: # use only 10 samples as validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train)) - else: - # Use the last 5k training examples as a validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train)) + centralized_data = 
centralized_data.select(range(10)) - valLoader = DataLoader(valset, batch_size=16) + val_loader = DataLoader(centralized_data, batch_size=16) # The `evaluate` function will be called after every round def evaluate( @@ -63,7 +60,7 @@ def evaluate( state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) model.load_state_dict(state_dict, strict=True) - loss, accuracy = utils.test(model, valLoader) + loss, accuracy = utils.test(model, val_loader) return loss, {"accuracy": accuracy} return evaluate @@ -79,23 +76,32 @@ def main(): parser = argparse.ArgumentParser(description="Flower") parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action="store_true", help="Set to true to use only 10 datasamples for validation. \ Useful for testing purposes. Default: False", ) + parser.add_argument( + "--model", + type=str, + default="efficientnet", + choices=["efficientnet", "alexnet"], + help="Use either Efficientnet or Alexnet models. \ + If you want to achieve differential privacy, please use the Alexnet model", + ) args = parser.parse_args() - model = utils.load_efficientnet(classes=10) + if args.model == "alexnet": + model = utils.load_alexnet(classes=10) + else: + model = utils.load_efficientnet(classes=10) model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()] # Create strategy strategy = fl.server.strategy.FedAvg( - fraction_fit=0.2, - fraction_evaluate=0.2, + fraction_fit=1.0, + fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=10, diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 8788ead90dee..fd9dab19a70d 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,59 +1,59 @@ import torch -import torchvision.transforms as transforms -from torchvision.datasets import CIFAR10 - +from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop +from torchvision.models import 
efficientnet_b0, AlexNet import warnings -warnings.filterwarnings("ignore") +from flwr_datasets import FederatedDataset -# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +warnings.filterwarnings("ignore") -def load_data(): - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - trainset = CIFAR10("./dataset", train=True, download=True, transform=transform) - testset = CIFAR10("./dataset", train=False, download=True, transform=transform) +def load_partition(partition_id, toy: bool = False): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(partition_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + partition_train_test = partition_train_test.with_transform(apply_transforms) + return partition_train_test["train"], partition_train_test["test"] - num_examples = {"trainset": len(trainset), "testset": len(testset)} - return trainset, testset, num_examples +def load_centralized_data(): + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + centralized_data = fds.load_split("test") + centralized_data = centralized_data.with_transform(apply_transforms) + return centralized_data -def load_partition(idx: int): - """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - trainset, testset, num_examples = load_data() - n_train = int(num_examples["trainset"] / 10) - n_test = int(num_examples["testset"] / 10) - train_parition = torch.utils.data.Subset( - trainset, range(idx * n_train, (idx + 1) * n_train) - ) - test_parition = torch.utils.data.Subset( - testset, range(idx * n_test, (idx + 1) * n_test) +def apply_transforms(batch): + 
"""Apply transforms to the partition from FederatedDataset.""" + pytorch_transforms = Compose( + [ + Resize(256), + CenterCrop(224), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] ) - return (train_parition, test_parition) + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch -def train(net, trainloader, valloader, epochs, device: str = "cpu"): +def train( + net, trainloader, valloader, epochs, device: torch.device = torch.device("cpu") +): """Train the network on the training set.""" print("Starting training...") net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss().to(device) optimizer = torch.optim.SGD( - net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4 + net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4 ) net.train() for _ in range(epochs): - for images, labels in trainloader: + for batch in trainloader: + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) optimizer.zero_grad() loss = criterion(net(images), labels) @@ -74,7 +74,9 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): return results -def test(net, testloader, steps: int = None, device: str = "cpu"): +def test( + net, testloader, steps: int = None, device: torch.device = torch.device("cpu") +): """Validate the network on the entire test set.""" print("Starting evalutation...") net.to(device) # move model to GPU if available @@ -82,7 +84,8 @@ def test(net, testloader, steps: int = None, device: str = "cpu"): correct, loss = 0, 0.0 net.eval() with torch.no_grad(): - for batch_idx, (images, labels) in enumerate(testloader): + for batch_idx, batch in enumerate(testloader): + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) outputs = net(images) loss += criterion(outputs, labels).item() @@ -95,36 +98,21 @@ def test(net, testloader, steps: int = None, device: str = 
"cpu"): return loss, accuracy -def replace_classifying_layer(efficientnet_model, num_classes: int = 10): - """Replaces the final layer of the classifier.""" - num_features = efficientnet_model.classifier.fc.in_features - efficientnet_model.classifier.fc = torch.nn.Linear(num_features, num_classes) - - -def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = None): - """Loads pretrained efficientnet model from torch hub. Replaces final classifying - layer if classes is specified. - - Args: - entrypoint: EfficientNet model to download. - For supported entrypoints, please refer - https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/ - classes: Number of classes in final classifying layer. Leave as None to get the downloaded - model untouched. - Returns: - EfficientNet Model - - Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch - """ - efficientnet = torch.hub.load( - "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True - ) - - if classes is not None: - replace_classifying_layer(efficientnet, classes) +def load_efficientnet(classes: int = 10): + """Loads EfficienNetB0 from TorchVision.""" + efficientnet = efficientnet_b0(pretrained=True) + # Re-init output linear layer with the right number of classes + model_classes = efficientnet.classifier[1].in_features + if classes != model_classes: + efficientnet.classifier[1] = torch.nn.Linear(model_classes, classes) return efficientnet def get_model_params(model): """Returns a model's parameters.""" return [val.cpu().numpy() for _, val in model.state_dict().items()] + + +def load_alexnet(classes): + """Load AlexNet model from TorchVision.""" + return AlexNet(num_classes=classes) diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index 31bf5edb64c6..94707b5cbc98 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,9 +1,9 @@ # Advanced 
Flower Example (TensorFlow/Keras) -This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) -- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that by default only a small subset of this data is used when running the `run.sh` script) +- Each client holds a local dataset of 1/10 of the train datasets and 80% is training examples and 20% as test examples (note that by default only a small subset of this data is used when running the `run.sh` script) - Server-side model evaluation after parameter aggregation - Hyperparameter schedule using config functions - Custom return values @@ -57,10 +57,11 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 2 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: +The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# Once you have activated your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. 
If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py index 1c0b61575635..17d1d2306270 100644 --- a/examples/advanced-tensorflow/client.py +++ b/examples/advanced-tensorflow/client.py @@ -6,6 +6,8 @@ import flwr as fl +from flwr_datasets import FederatedDataset + # Make TensorFlow logs less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -74,7 +76,7 @@ def main() -> None: # Parse command line argument `partition` parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -84,9 +86,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action="store_true", help="Set to true to quicky run the client using only 10 datasamples. " "Useful for testing purposes. 
Default: False", ) @@ -99,16 +99,16 @@ def main() -> None: model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) # Load a subset of CIFAR-10 to simulate the local data partition - (x_train, y_train), (x_test, y_test) = load_partition(args.partition) + x_train, y_train, x_test, y_test = load_partition(args.client_id) if args.toy: x_train, y_train = x_train[:10], y_train[:10] x_test, y_test = x_test[:10], y_test[:10] # Start Flower client - client = CifarClient(model, x_train, y_train, x_test, y_test) + client = CifarClient(model, x_train, y_train, x_test, y_test).to_client() - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", client=client, root_certificates=Path(".cache/certificates/ca.crt").read_bytes(), @@ -117,15 +117,16 @@ def main() -> None: def load_partition(idx: int): """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() - return ( - x_train[idx * 5000 : (idx + 1) * 5000], - y_train[idx * 5000 : (idx + 1) * 5000], - ), ( - x_test[idx * 1000 : (idx + 1) * 1000], - y_test[idx * 1000 : (idx + 1) * 1000], - ) + # Download and partition dataset + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(idx) + partition.set_format("numpy") + + # Divide data on each node: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + return x_train, y_train, x_test, y_test if __name__ == "__main__": diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 293ba64b3f43..02bd923129a4 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -6,10 +6,11 @@ 
build-backend = "poetry.core.masonry.api" name = "advanced-tensorflow" version = "0.1.0" description = "Advanced Flower/TensorFlow Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/advanced-tensorflow/requirements.txt b/examples/advanced-tensorflow/requirements.txt index 7a70c46a8128..0cb5fe8c07af 100644 --- a/examples/advanced-tensorflow/requirements.txt +++ b/examples/advanced-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" diff --git a/examples/advanced-tensorflow/run.sh b/examples/advanced-tensorflow/run.sh index 8ddb6a252b52..4acef1371571 100755 --- a/examples/advanced-tensorflow/run.sh +++ b/examples/advanced-tensorflow/run.sh @@ -5,14 +5,11 @@ echo "Starting server" python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start +sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset -# Ensure that the Keras dataset used in client.py is already cached. 
-python -c "import tensorflow as tf; tf.keras.datasets.cifar10.load_data()" - -for i in `seq 0 9`; do +for i in $(seq 0 9); do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/advanced-tensorflow/server.py b/examples/advanced-tensorflow/server.py index e1eb3d4fd8f7..e159a096dc83 100644 --- a/examples/advanced-tensorflow/server.py +++ b/examples/advanced-tensorflow/server.py @@ -4,6 +4,8 @@ import flwr as fl import tensorflow as tf +from flwr_datasets import FederatedDataset + def main() -> None: # Load and compile model for @@ -43,11 +45,11 @@ def main() -> None: def get_evaluate_fn(model): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data() - - # Use the last 5k training examples as a validation set - x_val, y_val = x_train[45000:50000], y_train[45000:50000] + # Load data here to avoid the overhead of doing it in `evaluate` itself + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + test = fds.load_split("test") + test.set_format("numpy") + x_test, y_test = test["img"] / 255.0, test["label"] # The `evaluate` function will be called after every round def evaluate( @@ -56,7 +58,7 @@ def evaluate( config: Dict[str, fl.common.Scalar], ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: model.set_weights(parameters) # Update model with the latest parameters - loss, accuracy = model.evaluate(x_val, y_val) + loss, accuracy = model.evaluate(x_test, y_test) return loss, {"accuracy": accuracy} return evaluate diff --git a/examples/android/README.md b/examples/android/README.md index 7931aa96b0c5..f9f2bb93b8dc 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -54,4 +54,4 @@ poetry run ./run.sh Download 
and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`. -When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training. +When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively. diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java index cbf804140954..4de369a7e684 100644 --- a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java +++ b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java @@ -60,7 +60,7 @@ public void onChanged(List workInfos) { if (workInfos.size() > 0) { WorkInfo info = workInfos.get(0); int progress = info.getProgress().getInt("progress", -1); - // You can recieve any message from the Worker Thread + // You can receive any message from the Worker Thread refreshRecyclerView(); } } diff --git a/examples/android/pyproject.toml b/examples/android/pyproject.toml index 2b9cd8c978a7..0371f7208292 100644 --- a/examples/android/pyproject.toml +++ b/examples/android/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.masonry.api" name = "android_flwr_tensorflow" version = "0.1.0" description = "Android Example" -authors = ["The Flower Authors "] +authors = ["The Flower 
Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md new file mode 100644 index 000000000000..14de3c7d632e --- /dev/null +++ b/examples/app-pytorch/README.md @@ -0,0 +1,75 @@ +# Flower App (PyTorch) 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to start a long-running Flower server (SuperLink) and then run a Flower App (consisting of a `ClientApp` and a `ServerApp`). + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. 
+├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── server_workflow.py # <-- contains `ServerApp` with workflow +├── server_custom.py # <-- contains `ServerApp` with custom main function +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Run a simulation + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 2 +``` + +## Run a deployment + +### Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +### Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +Or, to try the workflow example, run: + +```bash +flower-server-app server_workflow:app --insecure +``` + +Or, to try the custom server function example, run: + +```bash +flower-server-app server_custom:app --insecure +``` diff --git a/examples/app-pytorch/client.py b/examples/app-pytorch/client.py new file mode 100644 index 000000000000..ebbe977ecab1 --- /dev/null +++ b/examples/app-pytorch/client.py @@ -0,0 +1,51 @@ +from flwr.client import ClientApp, NumPyClient + +from task import ( + Net, + DEVICE, + load_data, + get_weights, + set_weights, + train, + test, +) + + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + + def fit(self, parameters, config): + 
set_weights(net, parameters) + results = train(net, trainloader, testloader, epochs=1, device=DEVICE) + return get_weights(net), len(trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(net, parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, +) + + +# Legacy mode +if __name__ == "__main__": + from flwr.client import start_client + + start_client( + server_address="127.0.0.1:8080", + client=FlowerClient().to_client(), + ) diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py new file mode 100644 index 000000000000..feea1ee658fe --- /dev/null +++ b/examples/app-pytorch/client_low_level.py @@ -0,0 +1,35 @@ +from flwr.client import ClientApp +from flwr.common import Message, Context + + +def hello_world_mod(msg, ctx, call_next) -> Message: + print("Hello, ...[pause for dramatic effect]...") + out = call_next(msg, ctx) + print("...[pause was long enough]... 
World!") + return out + + +# Flower ClientApp +app = ClientApp( + mods=[ + hello_world_mod, + ], +) + + +@app.train() +def train(msg: Message, ctx: Context): + print("`train` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") + + +@app.evaluate() +def eval(msg: Message, ctx: Context): + print("`evaluate` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") + + +@app.query() +def query(msg: Message, ctx: Context): + print("`query` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl="") diff --git a/examples/app-pytorch/pyproject.toml b/examples/app-pytorch/pyproject.toml new file mode 100644 index 000000000000..e47dd2db949d --- /dev/null +++ b/examples/app-pytorch/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "app-pytorch" +version = "0.1.0" +description = "Multi-Tenant Federated Learning with Flower and PyTorch" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = "^3.8" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } +flwr-datasets = { version = "0.0.2", extras = ["vision"] } +torch = "2.2.1" +torchvision = "0.17.1" diff --git a/examples/app-pytorch/requirements.txt b/examples/app-pytorch/requirements.txt new file mode 100644 index 000000000000..016a84043cbe --- /dev/null +++ b/examples/app-pytorch/requirements.txt @@ -0,0 +1,4 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 +flwr-datasets[vision]==0.0.2 +torch==2.2.1 +torchvision==0.17.1 diff --git a/examples/mt-pytorch/start_server.py b/examples/app-pytorch/server.py similarity index 64% rename from examples/mt-pytorch/start_server.py rename to examples/app-pytorch/server.py index d96edd7d45ad..0b4ad1ddba46 100644 --- a/examples/mt-pytorch/start_server.py +++ b/examples/app-pytorch/server.py @@ -1,7 +1,10 @@ from 
typing import List, Tuple -import flwr as fl -from flwr.common import Metrics +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.common import Metrics, ndarrays_to_parameters + +from task import Net, get_weights # Define metric aggregation function @@ -25,17 +28,38 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + # Define strategy -strategy = fl.server.strategy.FedAvg( +strategy = FedAvg( fraction_fit=1.0, # Select all available clients fraction_evaluate=0.0, # Disable evaluation min_available_clients=2, fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, ) -# Start Flower server -fl.server.start_server( - server_address="0.0.0.0:9092", - config=fl.server.ServerConfig(num_rounds=3), + +# Define config +config = ServerConfig(num_rounds=3) + + +# Flower ServerApp +app = ServerApp( + config=config, strategy=strategy, ) + + +# Legacy mode +if __name__ == "__main__": + from flwr.server import start_server + + start_server( + server_address="0.0.0.0:8080", + config=config, + strategy=strategy, + ) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py new file mode 100644 index 000000000000..0c2851e2afee --- /dev/null +++ b/examples/app-pytorch/server_custom.py @@ -0,0 +1,148 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.common import ( + Context, + FitIns, + ndarrays_to_parameters, + parameters_to_ndarrays, + NDArrays, + Code, + Message, + MessageType, + Metrics, +) +from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres +from flwr.server import Driver, History +from flwr.server.strategy.aggregate import aggregate + +from task import Net, get_weights + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, 
Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """.""" + print("RUNNING!!!!!") + + num_client_nodes_per_round = 2 + sleep_time = 1 + num_rounds = 3 + parameters = ndarrays_to_parameters(get_weights(net=Net())) + + history = History() + for server_round in range(num_rounds): + print(f"Commencing server round {server_round + 1}") + + # List of sampled node IDs in this round + sampled_nodes: List[int] = [] + + # The Driver API might not immediately return enough client node IDs, so we + # loop and wait until enough client nodes are available. 
+ while True: + all_node_ids = driver.get_node_ids() + + print(f"Got {len(all_node_ids)} client nodes: {all_node_ids}") + if len(all_node_ids) >= num_client_nodes_per_round: + # Sample client nodes + sampled_nodes = random.sample(all_node_ids, num_client_nodes_per_round) + break + time.sleep(3) + + # Log sampled node IDs + print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") + + # Schedule a task for all sampled nodes + fit_ins: FitIns = FitIns(parameters=parameters, config={}) + recordset = fitins_to_recordset(fitins=fit_ins, keep_input=True) + + messages = [] + for node_id in sampled_nodes: + message = driver.create_message( + content=recordset, + message_type=MessageType.TRAIN, + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + # Collect correct results + all_fitres = [ + recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + ] + print(f"Received {len(all_fitres)} results") + + weights_results: List[Tuple[NDArrays, int]] = [] + metrics_results: List[Tuple[int, Dict]] = [] + for fitres in all_fitres: + print(f"num_examples: {fitres.num_examples}, status: {fitres.status.code}") + + # Aggregate only if the status is OK + if fitres.status.code != Code.OK: + continue + weights_results.append( + (parameters_to_ndarrays(fitres.parameters), fitres.num_examples) + ) + metrics_results.append((fitres.num_examples, fitres.metrics)) + + # Aggregate parameters (FedAvg) + parameters_aggregated = 
ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated + + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + + # Slow down the start of the next round + time.sleep(sleep_time) + + print("app_fit: losses_distributed %s", str(history.losses_distributed)) + print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit)) + print("app_fit: metrics_distributed %s", str(history.metrics_distributed)) + print("app_fit: losses_centralized %s", str(history.losses_centralized)) + print("app_fit: metrics_centralized %s", str(history.metrics_centralized)) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py new file mode 100644 index 000000000000..560babac1b95 --- /dev/null +++ b/examples/app-pytorch/server_low_level.py @@ -0,0 +1,52 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.server import Driver + + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """This is a stub example that simply sends and receives messages.""" + print("Starting test run") + for server_round in range(3): + print(f"Commencing server round {server_round + 1}") + + # Get node IDs + node_ids = driver.get_node_ids() + + # Create messages + recordset = RecordSet() + messages = [] + for node_id in node_ids: + message = driver.create_message( + content=recordset, + message_type=MessageType.TRAIN, + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + # Send messages + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} 
messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + print(f"Received {len(all_replies)} results") diff --git a/examples/app-pytorch/server_workflow.py b/examples/app-pytorch/server_workflow.py new file mode 100644 index 000000000000..6923010ecf7b --- /dev/null +++ b/examples/app-pytorch/server_workflow.py @@ -0,0 +1,63 @@ +from typing import List, Tuple + +from task import Net, get_weights + +import flwr as fl +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server import Driver, LegacyContext + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable 
evaluation + min_available_clients=2, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, +) + + +# Run via `flower-server-app server_workflow:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = fl.server.workflow.DefaultWorkflow() + + # Execute + workflow(driver, context) diff --git a/examples/app-pytorch/task.py b/examples/app-pytorch/task.py new file mode 100644 index 000000000000..240f290df320 --- /dev/null +++ b/examples/app-pytorch/task.py @@ -0,0 +1,94 @@ +from collections import OrderedDict +from logging import INFO + +import torch +import torch.nn as nn +import torch.nn.functional as F +from flwr.common.logger import log +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, 
download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +def train(net, trainloader, valloader, epochs, device): + """Train the model on the training set.""" + log(INFO, "Starting training...") + net.to(device) # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net.train() + for _ in range(epochs): + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + loss.backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + "train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + """Validate the model on the test set.""" + net.to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in testloader: + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md new file mode 100644 index 000000000000..d1ea7bdc893f --- /dev/null +++ b/examples/app-secure-aggregation/README.md @@ -0,0 +1,93 @@ +# Secure aggregation with Flower (the SecAgg+ protocol) 🧪 + +> 🧪 = This example covers 
experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to use Secure Aggregation in Flower, with `ClientApp` using `secaggplus_mod` and `ServerApp` using `SecAggPlusWorkflow`. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. +├── client.py # Client application using `secaggplus_mod` +├── server.py # Server application using `SecAggPlusWorkflow` +├── workflow_with_log.py # Augmented `SecAggPlusWorkflow` +├── run.sh # Quick start script +├── pyproject.toml # Project dependencies (poetry) +└── requirements.txt # Project dependencies (pip) +``` + +## Installing dependencies + +Project dependencies (such as `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +If you don't see any errors you're good to go! 
+ +## Run the example with one command (recommended) + +```bash +./run.sh +``` + +## Run the example with the simulation engine + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 5 +``` + +## Alternatively, run the example (in 7 terminal windows) + +Start the Flower Superlink in one terminal window: + +```bash +flower-superlink --insecure +``` + +Start 5 Flower `ClientApp` in 5 separate terminal windows: + +```bash +flower-client-app client:app --insecure +``` + +Start the Flower `ServerApp`: + +```bash +flower-server-app server:app --insecure --verbose +``` + +## Amend the example for practical usage + +For real-world applications, modify the `workflow` in `server.py` as follows: + +```python +workflow = fl.server.workflow.DefaultWorkflow( + fit_workflow=SecAggPlusWorkflow( + num_shares=, + reconstruction_threshold=, + ) +) +``` diff --git a/examples/app-secure-aggregation/client.py b/examples/app-secure-aggregation/client.py new file mode 100644 index 000000000000..b2fd02ec00d4 --- /dev/null +++ b/examples/app-secure-aggregation/client.py @@ -0,0 +1,34 @@ +import time + +from flwr.client import ClientApp, NumPyClient +from flwr.client.mod import secaggplus_mod +import numpy as np + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + def fit(self, parameters, config): + # Instead of training and returning model parameters, + # the client directly returns [1.0, 1.0, 1.0] for demonstration purposes. 
+ ret_vec = [np.ones(3)] + # Force a significant delay for testing purposes + if "drop" in config and config["drop"]: + print(f"Client dropped for testing purposes.") + time.sleep(8) + else: + print(f"Client uploading {ret_vec[0]}...") + return ret_vec, 1, {} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + ], +) diff --git a/examples/app-secure-aggregation/pyproject.toml b/examples/app-secure-aggregation/pyproject.toml new file mode 100644 index 000000000000..84b6502064c8 --- /dev/null +++ b/examples/app-secure-aggregation/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "app-secure-aggregation" +version = "0.1.0" +description = "Flower Secure Aggregation example." +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = "^3.8" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } diff --git a/examples/app-secure-aggregation/requirements.txt b/examples/app-secure-aggregation/requirements.txt new file mode 100644 index 000000000000..5bac63a0d44c --- /dev/null +++ b/examples/app-secure-aggregation/requirements.txt @@ -0,0 +1 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 diff --git a/examples/app-secure-aggregation/run.sh b/examples/app-secure-aggregation/run.sh new file mode 100755 index 000000000000..fa8dc47f26ef --- /dev/null +++ b/examples/app-secure-aggregation/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Kill any currently running client.py processes +pkill -f 'flower-client-app' + +# Kill any currently running flower-superlink processes +pkill -f 'flower-superlink' + +# Start the flower server +echo "Starting flower server in background..." 
+flower-superlink --insecure > /dev/null 2>&1 & +sleep 2 + +# Number of client processes to start +N=5 # Replace with your desired value + +echo "Starting $N ClientApps in background..." + +# Start N client processes +for i in $(seq 1 $N) +do + flower-client-app --insecure client:app > /dev/null 2>&1 & + sleep 0.1 +done + +echo "Starting ServerApp..." +flower-server-app --insecure server:app --verbose + +echo "Clearing background processes..." + +# Kill any currently running client.py processes +pkill -f 'flower-client-app' + +# Kill any currently running flower-superlink processes +pkill -f 'flower-superlink' diff --git a/examples/app-secure-aggregation/server.py b/examples/app-secure-aggregation/server.py new file mode 100644 index 000000000000..e9737a5a3c7f --- /dev/null +++ b/examples/app-secure-aggregation/server.py @@ -0,0 +1,45 @@ +from flwr.common import Context +from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow + +from workflow_with_log import SecAggPlusWorkflowWithLogs + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable evaluation + min_available_clients=5, +) + + +# Flower ServerApp +app = ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = DefaultWorkflow( + fit_workflow=SecAggPlusWorkflowWithLogs( + num_shares=3, + reconstruction_threshold=2, + timeout=5, + ) + # # For real-world applications, use the following code instead + # fit_workflow=SecAggPlusWorkflow( + # num_shares=, + # reconstruction_threshold=, + # ) + ) + + # Execute + workflow(driver, context) diff --git a/examples/app-secure-aggregation/workflow_with_log.py 
b/examples/app-secure-aggregation/workflow_with_log.py new file mode 100644 index 000000000000..a03ff8c13b6c --- /dev/null +++ b/examples/app-secure-aggregation/workflow_with_log.py @@ -0,0 +1,92 @@ +from flwr.common import Context, log, parameters_to_ndarrays +from logging import INFO +from flwr.server import Driver, LegacyContext +from flwr.server.workflow.secure_aggregation.secaggplus_workflow import ( + SecAggPlusWorkflow, + WorkflowState, +) +import numpy as np +from flwr.common.secure_aggregation.quantization import quantize +from flwr.server.workflow.constant import MAIN_PARAMS_RECORD +import flwr.common.recordset_compat as compat + + +class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow): + """The SecAggPlusWorkflow augmented for this example. + + This class includes additional logging and modifies one of the FitIns to instruct + the target client to simulate a dropout. + """ + + node_ids = [] + + def __call__(self, driver: Driver, context: Context) -> None: + _quantized = quantize( + [np.ones(3) for _ in range(5)], self.clipping_range, self.quantization_range + ) + log(INFO, "") + log( + INFO, + "################################ Introduction ################################", + ) + log( + INFO, + "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of", + ) + log(INFO, "model updates for demonstration purposes.") + log( + INFO, + "Client 0 is configured to drop out before uploading the masked vector.", + ) + log(INFO, "After quantization, the raw vectors will look like:") + for i in range(1, 5): + log(INFO, "\t%s from Client %s", _quantized[i], i) + log( + INFO, + "Numbers are rounded to integers stochastically during the quantization", + ) + log(INFO, ", and thus entries may not be identical.") + log( + INFO, + "The above raw vectors are hidden from the driver through adding masks.", + ) + log(INFO, "") + log( + INFO, + "########################## Secure Aggregation Start ##########################", + ) + + super().__call__(driver, 
context) + + paramsrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] + parameters = compat.parametersrecord_to_parameters(paramsrecord, True) + ndarrays = parameters_to_ndarrays(parameters) + log( + INFO, + "Weighted average of vectors (dequantized): %s", + ndarrays[0], + ) + log( + INFO, + "########################### Secure Aggregation End ###########################", + ) + log(INFO, "") + + def setup_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + ret = super().setup_stage(driver, context, state) + self.node_ids = list(state.active_node_ids) + state.nid_to_fitins[self.node_ids[0]].configs_records["fitins.config"][ + "drop" + ] = True + return ret + + def collect_masked_vectors_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + ret = super().collect_masked_vectors_stage(driver, context, state) + for node_id in state.sampled_node_ids - state.active_node_ids: + log(INFO, "Client %s dropped out.", self.node_ids.index(node_id)) + log(INFO, "Obtained sum of masked vectors: %s", state.aggregate_ndarrays[1]) + return ret diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md new file mode 100644 index 000000000000..317fb6336106 --- /dev/null +++ b/examples/custom-metrics/README.md @@ -0,0 +1,106 @@ +# Flower Example using Custom Metrics + +This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score. + +Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other metrics or custom ones are possible to be calculated. 
+ +The main takeaways of this implementation are: + +- the use of the `output_dict` on the client side - inside `evaluate` method on `client.py` +- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server.py` + +This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.ai/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve the CIFAR-10. + +Using the CIFAR-10 dataset for classification, this is a multi-class classification problem, thus some changes on how to calculate the metrics using `average='micro'` and `np.argmax` is required. For binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on client side. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics +``` + +This will create a new directory called `custom-metrics` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- run.sh +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. 
+ +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Run Federated Learning with Custom Metrics + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: + +```shell +python client.py +``` + +Alternatively you can run all of it in one shell as follows: + +```shell +python server.py & +# Wait for a few seconds to give the server enough time to start, then: +python client.py & +python client.py +``` + +or + +```shell +chmod +x run.sh +./run.sh +``` + +You will see that Keras is starting a federated training. Have a look to the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development). 
+ +Running `run.sh` will result in the following output (after 3 rounds): + +```shell +INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { + 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], + 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] +} +``` diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py new file mode 100644 index 000000000000..6a194e92cdce --- /dev/null +++ b/examples/custom-metrics/client.py @@ -0,0 +1,73 @@ +import os + +import flwr as fl +import numpy as np +import tensorflow as tf +from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score +from flwr_datasets import FederatedDataset + + +# Make TensorFlow log less verbose +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + +# Load model (MobileNetV2) +model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) +model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) + +# Load data with Flower Datasets (CIFAR-10) +fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) +train = fds.load_split("train") +test = fds.load_split("test") + +# Using Numpy format +train_np = train.with_format("numpy") +test_np = test.with_format("numpy") +x_train, y_train = train_np["img"], train_np["label"] +x_test, y_test = test_np["img"], test_np["label"] + + +# Method for extra learning metrics calculation +def eval_learning(y_test, y_pred): + acc = accuracy_score(y_test, y_pred) + rec = recall_score( + y_test, y_pred, average="micro" + ) # average argument required for multi-class + prec = precision_score(y_test, y_pred, average="micro") + f1 = f1_score(y_test, y_pred, average="micro") + return acc, rec, prec, f1 + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def 
get_parameters(self, config): + return model.get_weights() + + def fit(self, parameters, config): + model.set_weights(parameters) + model.fit(x_train, y_train, epochs=1, batch_size=32) + return model.get_weights(), len(x_train), {} + + def evaluate(self, parameters, config): + model.set_weights(parameters) + loss, accuracy = model.evaluate(x_test, y_test) + y_pred = model.predict(x_test) + y_pred = np.argmax(y_pred, axis=1).reshape( + -1, 1 + ) # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable + + acc, rec, prec, f1 = eval_learning(y_test, y_pred) + output_dict = { + "accuracy": accuracy, # accuracy from tensorflow model.evaluate + "acc": acc, + "rec": rec, + "prec": prec, + "f1": f1, + } + return loss, len(x_test), output_dict + + +# Start Flower client +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml new file mode 100644 index 000000000000..51c29e213d81 --- /dev/null +++ b/examples/custom-metrics/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "custom-metrics" +version = "0.1.0" +description = "Federated Learning with Flower and Custom Metrics" +authors = [ + "The Flower Authors ", + "Gustavo Bertoli ", +] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +flwr-datasets = { version = "*", extras = ["vision"] } +scikit-learn = "^1.2.2" +tensorflow = "==2.12.0" diff --git a/examples/custom-metrics/requirements.txt b/examples/custom-metrics/requirements.txt new file mode 100644 index 000000000000..69d867c5f287 --- /dev/null +++ b/examples/custom-metrics/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0,<2.0 +flwr-datasets[vision] +scikit-learn>=1.2.2 +tensorflow==2.12.0 diff --git a/examples/custom-metrics/run.sh b/examples/custom-metrics/run.sh new file mode 100755 index 
000000000000..c64f362086aa --- /dev/null +++ b/examples/custom-metrics/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# This will allow you to use CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/custom-metrics/server.py b/examples/custom-metrics/server.py new file mode 100644 index 000000000000..f8420bf51f16 --- /dev/null +++ b/examples/custom-metrics/server.py @@ -0,0 +1,58 @@ +import flwr as fl +import numpy as np + + +# Define metrics aggregation function +def average_metrics(metrics): + """Aggregate metrics from multiple clients by calculating mean averages. + + Parameters: + - metrics (list): A list containing tuples, where each tuple represents metrics for a client. + Each tuple is structured as (num_examples, metric), where: + - num_examples (int): The number of examples used to compute the metrics. + - metric (dict): A dictionary containing custom metrics provided as `output_dict` + in the `evaluate` method from `client.py`. + + Returns: + A dictionary with the aggregated metrics, calculating mean averages. The keys of the + dictionary represent different metrics, including: + - 'accuracy': Mean accuracy calculated by TensorFlow. + - 'acc': Mean accuracy from scikit-learn. + - 'rec': Mean recall from scikit-learn. + - 'prec': Mean precision from scikit-learn. + - 'f1': Mean F1 score from scikit-learn. + + Note: If a weighted average is required, the `num_examples` parameter can be leveraged. 
+ + Example: + Example `metrics` list for two clients after the last round: + [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}), + (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})] + """ + + # Here num_examples are not taken into account by using _ + accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics]) + accuracies = np.mean([metric["acc"] for _, metric in metrics]) + recalls = np.mean([metric["rec"] for _, metric in metrics]) + precisions = np.mean([metric["prec"] for _, metric in metrics]) + f1s = np.mean([metric["f1"] for _, metric in metrics]) + + return { + "accuracy": accuracies_tf, + "acc": accuracies, + "rec": recalls, + "prec": precisions, + "f1": f1s, + } + + +# Define strategy and the custom aggregation function for the evaluation metrics +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics) + + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/custom-mods/.gitignore b/examples/custom-mods/.gitignore new file mode 100644 index 000000000000..260d28a67c6f --- /dev/null +++ b/examples/custom-mods/.gitignore @@ -0,0 +1,2 @@ +wandb/ +.runs_history/ diff --git a/examples/custom-mods/README.md b/examples/custom-mods/README.md new file mode 100644 index 000000000000..b0ad668c2dec --- /dev/null +++ b/examples/custom-mods/README.md @@ -0,0 +1,339 @@ +# Using custom mods 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. 
+ +The following steps describe how to write custom Flower Mods and use them in a simple example. + +## Writing custom Flower Mods + +### Flower Mods basics + +As described [here](https://flower.ai/docs/framework/how-to-use-built-in-mods.html#what-are-mods), Flower Mods in their simplest form can be described as: + +```python +def basic_mod(msg: Message, context: Context, app: ClientApp) -> Message: + # Do something with incoming Message (or Context) + # before passing to the inner ``ClientApp`` + reply = app(msg, context) + # Do something with outgoing Message (or Context) + # before returning + return reply +``` + +and used when defining the `ClientApp`: + +```python +app = fl.client.ClientApp( + client_fn=client_fn, + mods=[basic_mod], +) +``` + +Note that in this specific case, this mod won't modify anything, and perform FL as usual. + +### WandB Flower Mod + +If we want to write a mod to monitor our client-side training using [Weights & Biases](https://github.com/wandb/wandb), we can follow the steps below. + +First, we need to initialize our W&B project with the correct parameters: + +```python +wandb.init( + project=..., + group=..., + name=..., + id=..., + resume="allow", + reinit=True, +) +``` + +In our case, the group should be the `run_id`, specific to a `ServerApp` run, and the `name` should be the `node_id`. This will make it easy to navigate our W&B project, as for each run we will be able to see the computed results as a whole or for each individual client. + +The `id` needs to be unique, so it will be a combination of `run_id` and `node_id`. 
+ +In the end we have: + +```python +def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: + run_id = msg.metadata.run_id + group_name = f"Run ID: {run_id}" + + node_id = str(msg.metadata.dst_node_id) + run_name = f"Node ID: {node_id}" + + wandb.init( + project="Mod Name", + group=group_name, + name=run_name, + id=f"{run_id}_{node_id}", + resume="allow", + reinit=True, + ) +``` + +Now, before the message is processed by the server, we will store the starting time and the round number, in order to compute the time it took the client to perform its fit step. + +```python +server_round = int(msg.metadata.group_id) +start_time = time.time() +``` + +And then, we can send the message to the client: + +```python +reply = app(msg, context) +``` + +And now, with the message we got back, we can gather our metrics: + +```python +if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + + time_diff = time.time() - start_time + + metrics = reply.content.configs_records + + results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + results_to_log["fit_time"] = time_diff +``` + +Note that we store our metrics in the `results_to_log` variable and that we only initialize this variable when our client is sending back fit results (with content in it). 
 + +Finally, we can send our results to W&B using: + +```python +wandb.log(results_to_log, step=int(server_round), commit=True) +``` + +The complete mod becomes: + +```python +def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: + server_round = int(msg.metadata.group_id) + + if msg.metadata.message_type == MessageType.TRAIN and server_round == 1: + run_id = msg.metadata.run_id + group_name = f"Run ID: {run_id}" + + node_id = str(msg.metadata.dst_node_id) + run_name = f"Node ID: {node_id}" + + wandb.init( + project="Mod Name", + group=group_name, + name=run_name, + id=f"{run_id}_{node_id}", + resume="allow", + reinit=True, + ) + + start_time = time.time() + + reply = app(msg, context) + + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + + time_diff = time.time() - start_time + + metrics = reply.content.configs_records + + results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + + results_to_log["fit_time"] = time_diff + + wandb.log(results_to_log, step=int(server_round), commit=True) + + return reply +``` + +And it can be used like: + +```python +app = fl.client.ClientApp( + client_fn=client_fn, + mods=[wandb_mod], +) +``` + +If we want to pass an argument to our mod, we can use a wrapper function: + +```python +def get_wandb_mod(name: str) -> Mod: + def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: + server_round = int(msg.metadata.group_id) + + run_id = msg.metadata.run_id + group_name = f"Run ID: {run_id}" + + node_id = str(msg.metadata.dst_node_id) + run_name = f"Node ID: {node_id}" + + wandb.init( + project=name, + group=group_name, + name=run_name, + id=f"{run_id}_{node_id}", + resume="allow", + reinit=True, + ) + + start_time = time.time() + + reply = app(msg, context) + + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + + time_diff = time.time() - start_time + + metrics = reply.content.configs_records + + results_to_log = 
dict(metrics.get("fitres.metrics", ConfigsRecord())) + + results_to_log["fit_time"] = time_diff + + wandb.log(results_to_log, step=int(server_round), commit=True) + + return reply + + return wandb_mod +``` + +And use it like: + +```python +app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + get_wandb_mod("Custom mods example"), + ], +) +``` + +### TensorBoard Flower Mod + +The [TensorBoard](https://www.tensorflow.org/tensorboard) Mod will only differ in the initialization and how the data is sent to TensorBoard: + +```python +def get_tensorboard_mod(logdir) -> Mod: + os.makedirs(logdir, exist_ok=True) + + def tensorboard_mod( + msg: Message, context: Context, app: ClientAppCallable + ) -> Message: + logdir_run = os.path.join(logdir, str(msg.metadata.run_id)) + + node_id = str(msg.metadata.dst_node_id) + + server_round = int(msg.metadata.group_id) + + start_time = time.time() + + reply = app(msg, context) + + time_diff = time.time() - start_time + + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id)) + + metrics = dict( + reply.content.configs_records.get("fitres.metrics", ConfigsRecord()) + ) + + with writer.as_default(step=server_round): + tf.summary.scalar(f"fit_time", time_diff, step=server_round) + for metric in metrics: + tf.summary.scalar( + f"{metric}", + metrics[metric], + step=server_round, + ) + writer.flush() + + return reply + + return tensorboard_mod +``` + +For the initialization, TensorBoard uses a custom directory path, which can, in this case, be passed as an argument to the wrapper function. + +It can be used in the following way: + +```python +app = fl.client.ClientApp( + client_fn=client_fn, + mods=[get_tensorboard_mod(".runs_history/")], +) +``` + +## Running the example + +### Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. 
+├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies +``` + +### Install dependencies + +```bash +pip install -r requirements.txt +``` + +For [W&B](wandb.ai) you will also need a valid account. + +### Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client using: + +```bash +flower-client-app client:wandb_app --insecure +``` + +for W&B monitoring, or: + +```bash +flower-client-app client:tb_app --insecure +``` + +for TensorBoard. + +In yet another new terminal window, start the second long-running Flower client (with the mod of your choice): + +```bash +flower-client-app client:{wandb,tb}_app --insecure +``` + +### Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +### Check the results + +For W&B, you will need to login to the [website](wandb.ai). + +For TensorBoard, you will need to run the following command in your terminal: + +```sh +tensorboard --logdir +``` + +Where `` needs to be replaced by the directory passed as an argument to the wrapper function (`.runs_history/` by default). 
diff --git a/examples/custom-mods/client.py b/examples/custom-mods/client.py new file mode 100644 index 000000000000..2b87a24da19d --- /dev/null +++ b/examples/custom-mods/client.py @@ -0,0 +1,160 @@ +import logging +import os +import time + +import flwr as fl +import tensorflow as tf +import wandb +from flwr.common import ConfigsRecord +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common.context import Context +from flwr.common.message import Message +from flwr.common.constant import MessageType + +from task import ( + Net, + DEVICE, + load_data, + get_parameters, + set_parameters, + train, + test, +) + + +class WBLoggingFilter(logging.Filter): + def filter(self, record): + return ( + "login" in record.getMessage() + or "View project at" in record.getMessage() + or "View run at" in record.getMessage() + ) + + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return get_parameters(net) + + def fit(self, parameters, config): + set_parameters(net, parameters) + results = train(net, trainloader, testloader, epochs=1, device=DEVICE) + return get_parameters(net), len(trainloader.dataset), results + + def evaluate(self, parameters, config): + set_parameters(net, parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + return FlowerClient().to_client() + + +def get_wandb_mod(name: str) -> Mod: + def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message: + """Flower Mod that logs the metrics dictionary returned by the client's fit + function to Weights & Biases.""" + server_round = int(msg.metadata.group_id) + + if server_round == 1 and msg.metadata.message_type == MessageType.TRAIN: + run_id = msg.metadata.run_id + group_name = f"Run ID: {run_id}" + + node_id = 
str(msg.metadata.dst_node_id) + run_name = f"Node ID: {node_id}" + + wandb.init( + project=name, + group=group_name, + name=run_name, + id=f"{run_id}_{node_id}", + resume="allow", + reinit=True, + ) + + start_time = time.time() + + reply = app(msg, context) + + time_diff = time.time() - start_time + + # if the `ClientApp` just processed a "fit" message, let's log some metrics to W&B + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + + metrics = reply.content.configs_records + + results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord())) + + results_to_log["fit_time"] = time_diff + + wandb.log(results_to_log, step=int(server_round), commit=True) + + return reply + + return wandb_mod + + +def get_tensorboard_mod(logdir) -> Mod: + os.makedirs(logdir, exist_ok=True) + + def tensorboard_mod( + msg: Message, context: Context, app: ClientAppCallable + ) -> Message: + """Flower Mod that logs the metrics dictionary returned by the client's fit + function to TensorBoard.""" + logdir_run = os.path.join(logdir, str(msg.metadata.run_id)) + + node_id = str(msg.metadata.dst_node_id) + + server_round = int(msg.metadata.group_id) + + start_time = time.time() + + reply = app(msg, context) + + time_diff = time.time() - start_time + + # if the `ClientApp` just processed a "fit" message, let's log some metrics to TensorBoard + if reply.metadata.message_type == MessageType.TRAIN and reply.has_content(): + writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id)) + + metrics = dict( + reply.content.configs_records.get("fitres.metrics", ConfigsRecord()) + ) + + with writer.as_default(step=server_round): + tf.summary.scalar(f"fit_time", time_diff, step=server_round) + for metric in metrics: + tf.summary.scalar( + f"{metric}", + metrics[metric], + step=server_round, + ) + writer.flush() + + return reply + + return tensorboard_mod + + +# Run via `flower-client-app client:wandb_app` +wandb_app = fl.client.ClientApp( + 
client_fn=client_fn, + mods=[ + get_wandb_mod("Custom mods example"), + ], +) + +# Run via `flower-client-app client:tb_app` +tb_app = fl.client.ClientApp( + client_fn=client_fn, + mods=[ + get_tensorboard_mod(".runs_history/"), + ], +) diff --git a/examples/mt-pytorch/pyproject.toml b/examples/custom-mods/pyproject.toml similarity index 62% rename from examples/mt-pytorch/pyproject.toml rename to examples/custom-mods/pyproject.toml index 4978035495ea..e690e05bab8f 100644 --- a/examples/mt-pytorch/pyproject.toml +++ b/examples/custom-mods/pyproject.toml @@ -3,14 +3,16 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "mt-pytorch" +name = "app-pytorch" version = "0.1.0" description = "Multi-Tenant Federated Learning with Flower and PyTorch" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = {version = ">=1.0,<2.0", extras = ["rest", "simulation"]} +flwr = { path = "../../", develop = true, extras = ["simulation"] } +tensorboard = "2.16.2" torch = "1.13.1" torchvision = "0.14.1" tqdm = "4.65.0" +wandb = "0.16.3" diff --git a/examples/mt-pytorch/requirements.txt b/examples/custom-mods/requirements.txt similarity index 72% rename from examples/mt-pytorch/requirements.txt rename to examples/custom-mods/requirements.txt index ae0a65386f2b..75b2c1135f11 100644 --- a/examples/mt-pytorch/requirements.txt +++ b/examples/custom-mods/requirements.txt @@ -1,4 +1,6 @@ flwr-nightly[rest,simulation]>=1.0, <2.0 +tensorboard==2.16.2 torch==1.13.1 torchvision==0.14.1 tqdm==4.65.0 +wandb==0.16.3 diff --git a/examples/mt-pytorch/start_driver.py b/examples/custom-mods/server.py similarity index 68% rename from examples/mt-pytorch/start_driver.py rename to examples/custom-mods/server.py index 3241a548950a..c2d8a4fe5ee7 100644 --- a/examples/mt-pytorch/start_driver.py +++ b/examples/custom-mods/server.py @@ -9,12 +9,16 @@ def weighted_average(metrics: 
List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_losses = [ + num_examples * float(m["train_loss"]) for num_examples, m in metrics + ] train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics + num_examples * float(m["train_accuracy"]) for num_examples, m in metrics + ] + val_losses = [num_examples * float(m["val_loss"]) for num_examples, m in metrics] + val_accuracies = [ + num_examples * float(m["val_accuracy"]) for num_examples, m in metrics ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] # Aggregate and return custom metric (weighted average) return { @@ -33,9 +37,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: fit_metrics_aggregation_fn=weighted_average, ) -# Start Flower server -fl.driver.start_driver( - server_address="0.0.0.0:9091", + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp( config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, ) diff --git a/examples/mt-pytorch/task.py b/examples/custom-mods/task.py similarity index 100% rename from examples/mt-pytorch/task.py rename to examples/custom-mods/task.py diff --git a/examples/doc/source/_templates/base.html b/examples/doc/source/_templates/base.html index e4fe80720b74..08030fb08c15 100644 --- a/examples/doc/source/_templates/base.html +++ b/examples/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/examples/doc/source/conf.py b/examples/doc/source/conf.py index 01cbb48c1587..bf177aa5ae24 100644 --- a/examples/doc/source/conf.py +++ b/examples/doc/source/conf.py @@ -22,12 +22,15 @@ # -- 
Project information ----------------------------------------------------- +import datetime + + project = "Flower" -copyright = "2022 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags -release = "1.7.0" +release = "1.8.0" # -- General configuration --------------------------------------------------- @@ -73,7 +76,7 @@ html_title = f"Flower Examples {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/examples/" +html_baseurl = "https://flower.ai/docs/examples/" html_theme_options = { # diff --git a/examples/dp-sgd-mnist/pyproject.toml b/examples/dp-sgd-mnist/pyproject.toml index 3158ab3d52f6..161952fd2aa4 100644 --- a/examples/dp-sgd-mnist/pyproject.toml +++ b/examples/dp-sgd-mnist/pyproject.toml @@ -7,14 +7,14 @@ name = "dp-sgd-mnist" version = "0.1.0" description = "Federated training with Tensorflow Privacy" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = ">=3.8,<3.11" # flwr = { path = "../../", develop = true } # Development flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow-privacy = "0.8.10" diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile index ea63839bc9d6..a85c05c4bb7a 100644 --- a/examples/embedded-devices/Dockerfile +++ b/examples/embedded-devices/Dockerfile @@ -8,6 +8,7 @@ RUN pip3 install 
--upgrade pip # Install flower RUN pip3 install flwr>=1.0 +RUN pip3 install flwr-datasets>=0.2 RUN pip3 install tqdm==4.65.0 WORKDIR /client diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md index 4c79eafbbf84..f1c5931b823a 100644 --- a/examples/embedded-devices/README.md +++ b/examples/embedded-devices/README.md @@ -192,7 +192,8 @@ On the machine of your choice, launch the server: # Launch your server. # Will wait for at least 2 clients to be connected, then will train for 3 FL rounds # The command below will sample all clients connected (since sample_fraction=1.0) -python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 # append `--mnist` if you want to use that dataset/model setting +# The server is dataset agnostic (use the same command for MNIST and CIFAR10) +python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 ``` > If you are on macOS with Apple Silicon (i.e. M1, M2 chips), you might encounter a `grpcio`-related issue when launching your server. If you are in a conda environment you can solve this easily by doing: `pip uninstall grpcio` and then `conda install grpcio`. 
diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index 5d236c9e9389..6bd69c16567e 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -6,18 +6,19 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader, random_split -from torchvision.datasets import CIFAR10, MNIST +from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize, ToTensor from torchvision.models import mobilenet_v3_small from tqdm import tqdm +from flwr_datasets import FederatedDataset + parser = argparse.ArgumentParser(description="Flower Embedded devices") parser.add_argument( "--server_address", type=str, default="0.0.0.0:8080", - help=f"gRPC server address (deafault '0.0.0.0:8080')", + help=f"gRPC server address (default '0.0.0.0:8080')", ) parser.add_argument( "--cid", @@ -28,25 +29,13 @@ parser.add_argument( "--mnist", action="store_true", - help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use MNIST", + help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use " + "MNIST", ) - warnings.filterwarnings("ignore", category=UserWarning) NUM_CLIENTS = 50 -# a config for mobilenetv2 that works for -# small input sizes (i.e. 
32x32 as in CIFAR) -mb2_cfg = [ - (1, 16, 1, 1), - (6, 24, 2, 1), - (6, 32, 3, 2), - (6, 64, 4, 2), - (6, 96, 3, 1), - (6, 160, 3, 2), - (6, 320, 1, 1), -] - class Net(nn.Module): """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz').""" @@ -73,7 +62,9 @@ def train(net, trainloader, optimizer, epochs, device): """Train the model on the training set.""" criterion = torch.nn.CrossEntropyLoss() for _ in range(epochs): - for images, labels in tqdm(trainloader): + for batch in tqdm(trainloader): + batch = list(batch.values()) + images, labels = batch[0], batch[1] optimizer.zero_grad() criterion(net(images.to(device)), labels.to(device)).backward() optimizer.step() @@ -84,7 +75,9 @@ def test(net, testloader, device): criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): - for images, labels in tqdm(testloader): + for batch in tqdm(testloader): + batch = list(batch.values()) + images, labels = batch[0], batch[1] outputs = net(images.to(device)) labels = labels.to(device) loss += criterion(outputs, labels).item() @@ -95,44 +88,33 @@ def test(net, testloader, device): def prepare_dataset(use_mnist: bool): """Get MNIST/CIFAR-10 and return client partitions and global testset.""" - dataset = MNIST if use_mnist else CIFAR10 if use_mnist: + fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + img_key = "image" norm = Normalize((0.1307,), (0.3081,)) else: + fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS}) + img_key = "img" norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - - trf = Compose([ToTensor(), norm]) - trainset = dataset("./data", train=True, download=True, transform=trf) - testset = dataset("./data", train=False, download=True, transform=trf) - - print("Partitioning dataset (IID)...") - - # Split trainset into `num_partitions` trainsets - num_images = len(trainset) // NUM_CLIENTS - partition_len = [num_images] * NUM_CLIENTS - - trainsets = 
random_split( - trainset, partition_len, torch.Generator().manual_seed(2023) - ) - - val_ratio = 0.1 - - # Create dataloaders with train+val support - train_partitions = [] - val_partitions = [] - for trainset_ in trainsets: - num_total = len(trainset_) - num_val = int(val_ratio * num_total) - num_train = num_total - num_val - - for_train, for_val = random_split( - trainset_, [num_train, num_val], torch.Generator().manual_seed(2023) - ) - - train_partitions.append(for_train) - val_partitions.append(for_val) - - return train_partitions, val_partitions, testset + pytorch_transforms = Compose([ToTensor(), norm]) + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch[img_key] = [pytorch_transforms(img) for img in batch[img_key]] + return batch + + trainsets = [] + validsets = [] + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") + # Divide data on each node: 90% train, 10% test + partition = partition.train_test_split(test_size=0.1) + partition = partition.with_transform(apply_transforms) + trainsets.append(partition["train"]) + validsets.append(partition["test"]) + testset = fds.load_split("test") + testset = testset.with_transform(apply_transforms) + return trainsets, validsets, testset # Flower client, adapted from Pytorch quickstart/simulation example @@ -148,8 +130,6 @@ def __init__(self, trainset, valset, use_mnist): self.model = Net() else: self.model = mobilenet_v3_small(num_classes=10) - # let's not reduce spatial resolution too early - self.model.features[0][0].stride = (1, 1) # Determine device self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model.to(self.device) # send model to device @@ -200,15 +180,15 @@ def main(): assert args.cid < NUM_CLIENTS use_mnist = args.mnist - # Download CIFAR-10 dataset and partition it + # Download dataset and partition it trainsets, valsets, _ = prepare_dataset(use_mnist) # Start Flower client 
setting its associated data partition - fl.client.start_numpy_client( + fl.client.start_client( server_address=args.server_address, client=FlowerClient( trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist - ), + ).to_client(), ) diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py index 3457af1c7a66..49c63ce5d9dc 100644 --- a/examples/embedded-devices/client_tf.py +++ b/examples/embedded-devices/client_tf.py @@ -6,6 +6,8 @@ import tensorflow as tf from tensorflow import keras as keras +from flwr_datasets import FederatedDataset + parser = argparse.ArgumentParser(description="Flower Embedded devices") parser.add_argument( "--server_address", @@ -32,30 +34,28 @@ def prepare_dataset(use_mnist: bool): """Download and partitions the CIFAR-10/MNIST dataset.""" if use_mnist: - (x_train, y_train), testset = tf.keras.datasets.mnist.load_data() + fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + img_key = "image" else: - (x_train, y_train), testset = tf.keras.datasets.cifar10.load_data() + fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS}) + img_key = "img" partitions = [] - # We keep all partitions equal-sized in this example - partition_size = math.floor(len(x_train) / NUM_CLIENTS) - for cid in range(NUM_CLIENTS): - # Split dataset into non-overlapping NUM_CLIENT partitions - idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size - - x_train_cid, y_train_cid = ( - x_train[idx_from:idx_to] / 255.0, - y_train[idx_from:idx_to], + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") + partition.set_format("numpy") + # Divide data on each node: 90% train, 10% test + partition = partition.train_test_split(test_size=0.1) + x_train, y_train = ( + partition["train"][img_key] / 255.0, + partition["train"]["label"], ) - - # now partition into train/validation - # Use 10% of the client's training data 
for validation - split_idx = math.floor(len(x_train_cid) * 0.9) - - client_train = (x_train_cid[:split_idx], y_train_cid[:split_idx]) - client_val = (x_train_cid[split_idx:], y_train_cid[split_idx:]) - partitions.append((client_train, client_val)) - - return partitions, testset + x_test, y_test = partition["test"][img_key] / 255.0, partition["test"]["label"] + partitions.append(((x_train, y_train), (x_test, y_test))) + data_centralized = fds.load_split("test") + data_centralized.set_format("numpy") + x_centralized = data_centralized[img_key] / 255.0 + y_centralized = data_centralized["label"] + return partitions, (x_centralized, y_centralized) class FlowerClient(fl.client.NumPyClient): @@ -68,7 +68,7 @@ def __init__(self, trainset, valset, use_mnist: bool): # Instantiate model if use_mnist: # small model for MNIST - self.model = model = keras.Sequential( + self.model = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), keras.layers.Conv2D(32, kernel_size=(5, 5), activation="relu"), @@ -118,14 +118,16 @@ def main(): assert args.cid < NUM_CLIENTS use_mnist = args.mnist - # Download CIFAR-10 dataset and partition it + # Download dataset and partition it partitions, _ = prepare_dataset(use_mnist) trainset, valset = partitions[args.cid] # Start Flower client setting its associated data partition - fl.client.start_numpy_client( + fl.client.start_client( server_address=args.server_address, - client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist), + client=FlowerClient( + trainset=trainset, valset=valset, use_mnist=use_mnist + ).to_client(), ) diff --git a/examples/embedded-devices/requirements_pytorch.txt b/examples/embedded-devices/requirements_pytorch.txt index 797ca6db6244..f859c4efef17 100644 --- a/examples/embedded-devices/requirements_pytorch.txt +++ b/examples/embedded-devices/requirements_pytorch.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 tqdm==4.65.0 diff --git 
a/examples/embedded-devices/requirements_tf.txt b/examples/embedded-devices/requirements_tf.txt index c7068d40b9c2..ff65b9c31648 100644 --- a/examples/embedded-devices/requirements_tf.txt +++ b/examples/embedded-devices/requirements_tf.txt @@ -1,2 +1,3 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow >=2.9.1, != 2.11.1 diff --git a/examples/embedded-devices/server.py b/examples/embedded-devices/server.py index 2a15f792297e..2a6194aa5088 100644 --- a/examples/embedded-devices/server.py +++ b/examples/embedded-devices/server.py @@ -30,16 +30,11 @@ default=2, help="Minimum number of available clients required for sampling (default: 2)", ) -parser.add_argument( - "--mnist", - action="store_true", - help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use MNIST", -) # Define metric aggregation function def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - """Thist function averages teh `accuracy` metric sent by the clients in a `evaluate` + """This function averages the `accuracy` metric sent by the clients in an `evaluate` stage (i.e. clients received the global model and evaluate it on their local validation sets).""" # Multiply accuracy of each client by number of examples used diff --git a/examples/federated-kaplan-meier-fitter/README.md b/examples/federated-kaplan-meier-fitter/README.md new file mode 100644 index 000000000000..1569467d6f82 --- /dev/null +++ b/examples/federated-kaplan-meier-fitter/README.md @@ -0,0 +1,104 @@ +# Flower Example using KaplanMeierFitter + +This is an introductory example on **federated survival analysis** using [Flower](https://flower.ai/) +and [lifelines](https://lifelines.readthedocs.io/en/stable/index.html) library. 
 + +The aim of this example is to estimate the survival function using the +[Kaplan-Meier Estimate](https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator) implemented in +lifelines library (see [KaplanMeierFitter](https://lifelines.readthedocs.io/en/stable/fitters/univariate/KaplanMeierFitter.html#lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter)). The distributed/federated aspect of this example +is the sending of the data to the server. You can think of it as a federated analytics example. However, it's worth noting that this procedure violates privacy since the raw data is exchanged. + +Finally, many other estimators beyond KaplanMeierFitter can be used with the provided strategy: +AalenJohansenFitter, GeneralizedGammaFitter, LogLogisticFitter, +SplineFitter, and WeibullFitter. + +We also use the [NaturalIdPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.NaturalIdPartitioner.html#flwr_datasets.partitioner.NaturalIdPartitioner) from [Flower Datasets](https://flower.ai/docs/datasets/) to divide the data according to +the group it comes from, therefore simulating the division that might occur. + 

+Survival Function +

+ +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +$ git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/federated-kaplan-meier-fitter . && rm -rf _tmp && cd federated-kaplan-meier-fitter +``` + +This will create a new directory called `federated-kaplan-meier-fitter` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- centralized.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `lifelines` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. 
+ +```shell +pip install -r requirements.txt +``` + +## Run Federated Survival Analysis with Flower and lifelines's KaplanMeierFitter + +### Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:node_1_app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:node_2_app --insecure +``` + +### Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +You will see that the server is printing survival function, median survival time and saves the plot with the survival function. + +You can also check that the results match the centralized version. + +```shell +$ python3 centralized.py +``` diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png new file mode 100644 index 000000000000..b7797f0879f2 Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png differ diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png new file mode 100644 index 000000000000..b7797f0879f2 Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png differ diff --git a/examples/federated-kaplan-meier-fitter/centralized.py b/examples/federated-kaplan-meier-fitter/centralized.py new file mode 100644 index 000000000000..c94edfef9a18 --- /dev/null +++ b/examples/federated-kaplan-meier-fitter/centralized.py @@ -0,0 +1,16 @@ +import 
matplotlib.pyplot as plt +from lifelines import KaplanMeierFitter +from lifelines.datasets import load_waltons + +if __name__ == "__main__": + X = load_waltons() + fitter = KaplanMeierFitter() + fitter.fit(X["T"], X["E"]) + print("Survival function") + print(fitter.survival_function_) + print("Mean survival time:") + print(fitter.median_survival_time_) + fitter.plot_survival_function() + plt.title("Survival function of fruit flies (Walton's data)", fontsize=16) + plt.savefig("./_static/survival_function_centralized.png", dpi=200) + print("Centralized survival function saved.") diff --git a/examples/federated-kaplan-meier-fitter/client.py b/examples/federated-kaplan-meier-fitter/client.py new file mode 100644 index 000000000000..948492efc575 --- /dev/null +++ b/examples/federated-kaplan-meier-fitter/client.py @@ -0,0 +1,65 @@ +from typing import Dict, List, Tuple + +import flwr as fl +import numpy as np +from datasets import Dataset +from flwr.common import NDArray, NDArrays +from flwr_datasets.partitioner import NaturalIdPartitioner +from lifelines.datasets import load_waltons + + +class FlowerClient(fl.client.NumPyClient): + """Flower client that holds and sends the events and times data. + + Parameters + ---------- + times: NDArray + Times of the `events`. + events: NDArray + Events represented by 0 - no event, 1 - event occurred. + + Raises + ------ + ValueError + If the `times` and `events` are not the same shape. 
+ """ + + def __init__(self, times: NDArray, events: NDArray): + if len(times) != len(events): + raise ValueError("The times and events arrays have to be same shape.") + self._times = times + self._events = events + + def fit( + self, parameters: List[np.ndarray], config: Dict[str, str] + ) -> Tuple[NDArrays, int, Dict]: + return ( + [self._times, self._events], + len(self._times), + {}, + ) + + +# Prepare data +X = load_waltons() +partitioner = NaturalIdPartitioner(partition_by="group") +partitioner.dataset = Dataset.from_pandas(X) + + +def get_client_fn(partition_id: int): + def client_fn(cid: str): + partition = partitioner.load_partition(partition_id).to_pandas() + events = partition["E"].values + times = partition["T"].values + return FlowerClient(times=times, events=events).to_client() + + return client_fn + + +# Run via `flower-client-app client:app` +node_1_app = fl.client.ClientApp( + client_fn=get_client_fn(0), +) +node_2_app = fl.client.ClientApp( + client_fn=get_client_fn(1), +) diff --git a/examples/federated-kaplan-meier-fitter/pyproject.toml b/examples/federated-kaplan-meier-fitter/pyproject.toml new file mode 100644 index 000000000000..8fe354ffb750 --- /dev/null +++ b/examples/federated-kaplan-meier-fitter/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "federated-kaplan-meier-fitter" +version = "0.1.0" +description = "Federated Kaplan Meier Fitter with Flower" +authors = ["The Flower Authors "] +maintainers = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.9,<3.11" +flwr-nightly = "*" +flwr-datasets = ">=0.0.2,<1.0.0" +numpy = ">=1.23.2" +pandas = ">=2.0.0" +lifelines = ">=0.28.0" diff --git a/examples/federated-kaplan-meier-fitter/requirements.txt b/examples/federated-kaplan-meier-fitter/requirements.txt new file mode 100644 index 000000000000..cc8146545c7b --- /dev/null +++ 
b/examples/federated-kaplan-meier-fitter/requirements.txt @@ -0,0 +1,5 @@ +flwr-nightly +flwr-datasets>=0.0.2, <1.0.0 +numpy>=1.23.2 +pandas>=2.0.0 +lifelines>=0.28.0 diff --git a/examples/federated-kaplan-meier-fitter/server.py b/examples/federated-kaplan-meier-fitter/server.py new file mode 100644 index 000000000000..141504ab59c0 --- /dev/null +++ b/examples/federated-kaplan-meier-fitter/server.py @@ -0,0 +1,145 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Strategy that supports many univariate fitters from lifelines library.""" + +from typing import Dict, List, Optional, Tuple, Union, Any + +import numpy as np +import flwr as fl +import matplotlib.pyplot as plt +from flwr.common import ( + FitIns, + Parameters, + Scalar, + EvaluateRes, + EvaluateIns, + FitRes, + parameters_to_ndarrays, +) +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import Strategy +from lifelines import KaplanMeierFitter + + +class EventTimeFitterStrategy(Strategy): + """Federated strategy to aggregate the data that consist of events and times. + + It works with the following uni-variate fitters from the lifelines library: + AalenJohansenFitter, GeneralizedGammaFitter, KaplanMeierFitter, LogLogisticFitter, + SplineFitter, WeibullFitter. 
Note that each of them might require slightly different + initialization but constructed fitter object are required to be passed. + + This strategy recreates the event and time data based on the data received from the + nodes. + + Parameters + ---------- + min_num_clients: int + pass + fitter: Any + uni-variate fitter from lifelines library that works with event, time data e.g. + KaplanMeierFitter + """ + + def __init__(self, min_num_clients: int, fitter: Any): + # Fitter can be access after the federated training ends + self._min_num_clients = min_num_clients + self.fitter = fitter + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the fit method.""" + config = {} + fit_ins = FitIns(parameters, config) + clients = client_manager.sample( + num_clients=client_manager.num_available(), + min_num_clients=self._min_num_clients, + ) + return [(client, fit_ins) for client in clients] + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Merge data and perform the fitting of the fitter from lifelines library. + + Assume just a single federated learning round. Assume the data comes as a list + with two elements of 1-dim numpy arrays of events and times. 
+ """ + remote_data = [ + (parameters_to_ndarrays(fit_res.parameters)) for _, fit_res in results + ] + + combined_times = remote_data[0][0] + combined_events = remote_data[0][1] + + for te_list in remote_data[1:]: + combined_times = np.concatenate((combined_times, te_list[0])) + combined_events = np.concatenate((combined_events, te_list[1])) + + args_sorted = np.argsort(combined_times) + sorted_times = combined_times[args_sorted] + sorted_events = combined_events[args_sorted] + self.fitter.fit(sorted_times, sorted_events) + print("Survival function:") + print(self.fitter.survival_function_) + self.fitter.plot_survival_function() + plt.title("Survival function of fruit flies (Walton's data)", fontsize=16) + plt.savefig("./_static/survival_function_federated.png", dpi=200) + print("Mean survival time:") + print(self.fitter.median_survival_time_) + return None, {} + + # The methods below return None or empty results. + # They need to be implemented to since the methods are abstract in the parent class + def initialize_parameters( + self, client_manager: Optional[ClientManager] = None + ) -> Optional[Parameters]: + """No parameter initialization is needed.""" + return None + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """No centralized evaluation.""" + return None + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """No federated evaluation.""" + return None, {} + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """No federated evaluation.""" + return [] + + +fitter = KaplanMeierFitter() # You can choose other method that work on E, T data +strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=fitter) + +app 
= fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=1), + strategy=strategy, +) diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md new file mode 100644 index 000000000000..99a0a7e50980 --- /dev/null +++ b/examples/fl-dp-sa/README.md @@ -0,0 +1,22 @@ +# fl_dp_sa + +This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation. +Note: This example is designed for a small number of rounds and is intended for demonstration purposes. + +## Install dependencies + +```bash +# Using pip +pip install . + +# Or using Poetry +poetry install +``` + +## Run + +The example uses the CIFAR-10 dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`. + +```shell +flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100 +``` diff --git a/examples/fl-dp-sa/fl_dp_sa/__init__.py b/examples/fl-dp-sa/fl_dp_sa/__init__.py new file mode 100644 index 000000000000..741260348ab8 --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/__init__.py @@ -0,0 +1 @@ +"""fl_dp_sa: A Flower / PyTorch app.""" diff --git a/examples/fl-dp-sa/fl_dp_sa/client.py b/examples/fl-dp-sa/fl_dp_sa/client.py new file mode 100644 index 000000000000..104264158833 --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/client.py @@ -0,0 +1,43 @@ +"""fl_dp_sa: A Flower / PyTorch app.""" + +from flwr.client import ClientApp, NumPyClient +from flwr.client.mod import fixedclipping_mod, secaggplus_mod + +from fl_dp_sa.task import DEVICE, Net, get_weights, load_data, set_weights, test, train + + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + def __init__(self, trainloader, testloader) -> None: + self.trainloader = trainloader + self.testloader = testloader + + def fit(self, parameters, config): + set_weights(net, 
parameters) + results = train(net, self.trainloader, self.testloader, epochs=1, device=DEVICE) + return get_weights(net), len(self.trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(net, parameters) + loss, accuracy = test(net, self.testloader) + return loss, len(self.testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + trainloader, testloader = load_data(partition_id=int(cid)) + return FlowerClient(trainloader, testloader).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + fixedclipping_mod, + ], +) diff --git a/examples/fl-dp-sa/fl_dp_sa/server.py b/examples/fl-dp-sa/fl_dp_sa/server.py new file mode 100644 index 000000000000..f7da75997e98 --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/server.py @@ -0,0 +1,77 @@ +"""fl_dp_sa: A Flower / PyTorch app.""" + +from typing import List, Tuple + +from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server.strategy import ( + DifferentialPrivacyClientSideFixedClipping, + FedAvg, +) +from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow + +from fl_dp_sa.task import Net, get_weights + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": 
sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = FedAvg( + fraction_fit=0.2, + fraction_evaluate=0.0, # Disable evaluation for demo purpose + min_fit_clients=20, + min_available_clients=20, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, +) +strategy = DifferentialPrivacyClientSideFixedClipping( + strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20 +) + + +app = ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the train/evaluate workflow + workflow = DefaultWorkflow( + fit_workflow=SecAggPlusWorkflow( + num_shares=7, + reconstruction_threshold=4, + ) + ) + + # Execute + workflow(driver, context) diff --git a/examples/fl-dp-sa/fl_dp_sa/task.py b/examples/fl-dp-sa/fl_dp_sa/task.py new file mode 100644 index 000000000000..3d506263d5a3 --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/task.py @@ -0,0 +1,110 @@ +"""fl_dp_sa: A Flower / PyTorch app.""" + +from collections import OrderedDict +from logging import INFO +from flwr_datasets import FederatedDataset + +import torch +import torch.nn as nn +import torch.nn.functional as F +from flwr.common.logger import log +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize, ToTensor + + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model.""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 3, padding=1) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = 
nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.size(0) + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(batch_size, -1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(partition_id): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + partition = fds.load_partition(partition_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))]) + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["image"] = [pytorch_transforms(img) for img in batch["image"]] + return batch + + partition_train_test = partition_train_test.with_transform(apply_transforms) + trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) + testloader = DataLoader(partition_train_test["test"], batch_size=32) + return trainloader, testloader + + +def train(net, trainloader, valloader, epochs, device): + """Train the model on the training set.""" + net.to(device) # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.Adam(net.parameters()) + net.train() + for _ in range(epochs): + for batch in trainloader: + images = batch["image"].to(device) + labels = batch["label"].to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + loss.backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + "train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + 
"""Validate the model on the test set.""" + net.to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for batch in testloader: + images = batch["image"].to(DEVICE) + labels = batch["label"].to(DEVICE) + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/examples/fl-dp-sa/flower.toml b/examples/fl-dp-sa/flower.toml new file mode 100644 index 000000000000..ea2e98206791 --- /dev/null +++ b/examples/fl-dp-sa/flower.toml @@ -0,0 +1,13 @@ +[project] +name = "fl_dp_sa" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[flower.components] +serverapp = "fl_dp_sa.server:app" +clientapp = "fl_dp_sa.client:app" diff --git a/examples/fl-dp-sa/pyproject.toml b/examples/fl-dp-sa/pyproject.toml new file mode 100644 index 000000000000..d30fa4675e34 --- /dev/null +++ b/examples/fl-dp-sa/pyproject.toml @@ -0,0 +1,21 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "fl-dp-sa" +version = "0.1.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.9" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240313", extras = ["simulation"] } +flwr-datasets = { version = "0.0.2", extras = ["vision"] } +torch = "2.2.1" +torchvision = "0.17.1" diff --git 
a/examples/fl-dp-sa/requirements.txt b/examples/fl-dp-sa/requirements.txt new file mode 100644 index 000000000000..ddb8a814447b --- /dev/null +++ b/examples/fl-dp-sa/requirements.txt @@ -0,0 +1,4 @@ +flwr-nightly[simulation]==1.8.0.dev20240313 +flwr-datasets[vision]==0.0.2 +torch==2.2.1 +torchvision==0.17.1 diff --git a/examples/flower-in-30-minutes/tutorial.ipynb b/examples/flower-in-30-minutes/tutorial.ipynb index 336ec4c19644..0e42cff924e8 100644 --- a/examples/flower-in-30-minutes/tutorial.ipynb +++ b/examples/flower-in-30-minutes/tutorial.ipynb @@ -7,11 +7,11 @@ "source": [ "Welcome to the 30 minutes Flower federated learning tutorial!\n", "\n", - "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.dev/).\n", + "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.ai/).\n", "\n", "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed. A minimal understanding of ML is not required but if you already know about it, nothing is stopping your from modifying this code as you see fit!\n", "\n", - "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.dev/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n", + "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.ai/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n", "\n", "Let's get stated!" 
] @@ -59,7 +59,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ model in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." + "We will be using the _simulation_ model in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -69,7 +69,7 @@ "source": [ "## Install your ML framework\n", "\n", - "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart- example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. 
And check the [Flower Documentation](https://flower.dev/docs/) for even more learning materials.\n", + "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart- example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. And check the [Flower Documentation](https://flower.ai/docs/) for even more learning materials.\n", "\n", "In this tutorial we are going to use PyTorch, so let's install a recent version. In this tutorial we'll use a small model so using CPU only training will suffice (this will also prevent Colab from abruptly terminating your experiment if resource limits are exceeded)" ] @@ -631,11 +631,11 @@ "\n", "\n", "class FlowerClient(fl.client.NumPyClient):\n", - " def __init__(self, trainloader, vallodaer) -> None:\n", + " def __init__(self, trainloader, valloader) -> None:\n", " super().__init__()\n", "\n", " self.trainloader = trainloader\n", - " self.valloader = vallodaer\n", + " self.valloader = valloader\n", " self.model = Net(num_classes=10)\n", "\n", " def set_parameters(self, parameters):\n", @@ -714,7 +714,7 @@ "metadata": {}, "outputs": [], "source": [ - "def get_evalulate_fn(testloader):\n", + "def get_evaluate_fn(testloader):\n", " \"\"\"This is a function that returns a function. The returned\n", " function (i.e. 
`evaluate_fn`) will be executed by the strategy\n", " at the end of each round to evaluate the stat of the global\n", @@ -747,7 +747,7 @@ " fraction_fit=0.1, # let's sample 10% of the client each round to do local training\n", " fraction_evaluate=0.1, # after each round, let's sample 20% of the clients to asses how well the global model is doing\n", " min_available_clients=100, # total number of clients available in the experiment\n", - " evaluate_fn=get_evalulate_fn(testloader),\n", + " evaluate_fn=get_evaluate_fn(testloader),\n", ") # a callback to a function that the strategy can execute to evaluate the state of the global model on a centralised dataset" ] }, @@ -775,8 +775,8 @@ " \"\"\"Returns a FlowerClient containing the cid-th data partition\"\"\"\n", "\n", " return FlowerClient(\n", - " trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n", - " )\n", + " trainloader=trainloaders[int(cid)], valloader=valloaders[int(cid)]\n", + " ).to_client()\n", "\n", " return client_fn\n", "\n", @@ -853,21 +853,21 @@ "\n", "Well, if you enjoyed this content, consider giving us a ⭐️ on GitHub -> https://github.com/adap/flower\n", "\n", - "* **[DOCS]** How about running your Flower clients on the GPU? find out how to do it in the [Flower Simulation Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html)\n", + "* **[DOCS]** How about running your Flower clients on the GPU? 
find out how to do it in the [Flower Simulation Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html)\n", "\n", "* **[VIDEO]** You can follow our [detailed line-by-line 9-videos tutorial](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB) about everything you need to know to design your own Flower Simulation pipelines\n", "\n", "* Check more advanced simulation examples the Flower GitHub:\n", "\n", " * Flower simulation with Tensorflow/Keras: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)\n", - " \n", + "\n", " * Flower simulation with Pytorch: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)\n", "\n", - "* **[DOCS]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** All Flower examples: https://flower.ai/docs/examples/\n", "\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n" ] }, { diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py index b10c32d36c42..eac831ad1932 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py @@ -98,7 +98,7 @@ def client_fn(cid: str): trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)], num_classes=num_classes, - ) + ).to_client() # return the function to spawn client return client_fn diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py index f5c76ab6dc99..f8124b9353f7 100644 --- 
a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py @@ -24,7 +24,7 @@ def main(cfg: DictConfig): save_path = HydraConfig.get().runtime.output_dir ## 2. Prepare your dataset - # When simulating FL workloads we have a lot of freedom on how the FL clients behave, + # When simulating FL runs we have a lot of freedom on how the FL clients behave, # what data they have, how much data, etc. This is not possible in real FL settings. # In simulation you'd often encounter two types of dataset: # * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST, @@ -91,7 +91,7 @@ def main(cfg: DictConfig): "num_gpus": 0.0, }, # (optional) controls the degree of parallelism of your simulation. # Lower resources per client allow for more clients to run concurrently - # (but need to be set taking into account the compute/memory footprint of your workload) + # (but need to be set taking into account the compute/memory footprint of your run) # `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated # `num_gpus` is a ratio indicating the portion of gpu memory that a client needs. 
) diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py b/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py index d269d4892a0e..7da9547d7362 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-II/client.py @@ -75,6 +75,6 @@ def client_fn(cid: str): trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)], model_cfg=model_cfg, - ) + ).to_client() return client_fn diff --git a/examples/flower-simulation-step-by-step-pytorch/README.md b/examples/flower-simulation-step-by-step-pytorch/README.md index 55b8d837b090..beb8dd7f6f95 100644 --- a/examples/flower-simulation-step-by-step-pytorch/README.md +++ b/examples/flower-simulation-step-by-step-pytorch/README.md @@ -1,5 +1,7 @@ # Flower Simulation Step-by-Step +> Since this tutorial (and its video series) was put together, Flower has been updated a few times. As a result, some of the steps to construct the environment (see below) have been updated. Some parts of the code have also been updated. Overall, the content of this tutorial and how things work remains the same as in the video tutorials. + This directory contains the code developed in the `Flower Simulation` tutorial series on Youtube. You can find all the videos [here](https://www.youtube.com/playlist?list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB) or clicking on the video preview below. - In `Part-I` (7 videos) we developed from scratch a complete Federated Learning pipeline for simulation using PyTorch. @@ -19,20 +21,17 @@ As presented in the video, we first need to create a Python environment. 
You are # I'm assuming you are running this on an Ubuntu 22.04 machine (GPU is not required) # create the environment -conda create -n flower_tutorial python=3.8 -y +conda create -n flower_tutorial python=3.9 -y # activate your environment (depending on how you installed conda you might need to use `conda activate ...` instead) source activate flower_tutorial # install PyToch (other versions would likely work) conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y -# conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 -c pytorch # If you don't have a GPU +# conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 -c pytorch -y # If you don't have a GPU -# install flower (for FL) and hydra (for configs) -pip install flwr==1.4.0 hydra-core==1.3.2 -# install ray -# you might see some warning messages after installing it (you can ignore them) -pip install ray==1.11.1 +# Install Flower and other dependencies +pip install -r requirements.txt ``` If you are running this on macOS with Apple Silicon (i.e. M1, M2), you'll need a different `grpcio` package if you see an error when running the code. 
To fix this do: diff --git a/examples/flower-simulation-step-by-step-pytorch/requirements.txt b/examples/flower-simulation-step-by-step-pytorch/requirements.txt new file mode 100644 index 000000000000..a322192ca711 --- /dev/null +++ b/examples/flower-simulation-step-by-step-pytorch/requirements.txt @@ -0,0 +1,2 @@ +flwr[simulation]>=1.0, <2.0 +hydra-core==1.3.2 \ No newline at end of file diff --git a/examples/flower-via-docker-compose/.gitignore b/examples/flower-via-docker-compose/.gitignore new file mode 100644 index 000000000000..de5b2e7692bf --- /dev/null +++ b/examples/flower-via-docker-compose/.gitignore @@ -0,0 +1,17 @@ +# ignore __pycache__ directories +__pycache__/ + +# ignore .pyc files +*.pyc + +# ignore .vscode directory +.vscode/ + +# ignore .npz files +*.npz + +# ignore .csv files +*.csv + +# ignore docker-compose.yaml file +docker-compose.yml \ No newline at end of file diff --git a/examples/flower-via-docker-compose/Dockerfile b/examples/flower-via-docker-compose/Dockerfile new file mode 100644 index 000000000000..ee6fee3103a5 --- /dev/null +++ b/examples/flower-via-docker-compose/Dockerfile @@ -0,0 +1,19 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim-buster + +# Set the working directory in the container to /app +WORKDIR /app + +# Copy the requirements file into the container +COPY ./requirements.txt /app/requirements.txt + +# Install gcc and other dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + python3-dev && \ + rm -rf /var/lib/apt/lists/* + +# Install any needed packages specified in requirements.txt +RUN pip install -r requirements.txt + + diff --git a/examples/flower-via-docker-compose/README.md b/examples/flower-via-docker-compose/README.md new file mode 100644 index 000000000000..3ef1ac37bcda --- /dev/null +++ b/examples/flower-via-docker-compose/README.md @@ -0,0 +1,254 @@ +# Leveraging Flower and Docker for Device Heterogeneity Management in Federated Learning + +

+ Flower Website + Docker Logo +

+ +## Introduction + +In this example, we tackle device heterogeneity in federated learning, arising from differences in memory and CPU capabilities across devices. This diversity affects training efficiency and inclusivity. Our strategy includes simulating this heterogeneity by setting CPU and memory limits in a Docker setup, using a custom Docker compose generator script. This approach creates a varied training environment and enables us to develop strategies to manage these disparities effectively. + +## Handling Device Heterogeneity + +1. **System Metrics Access**: + + - Effective management of device heterogeneity begins with monitoring system metrics of each container. We integrate the following services to achieve this: + - **Cadvisor**: Collects comprehensive metrics from each Docker container. + - **Prometheus**: Using `prometheus.yaml` for configuration, it scrapes data from Cadvisor at scheduled intervals, serving as a robust time-series database. Users can access the Prometheus UI at `http://localhost:9090` to create and run queries using PromQL, allowing for detailed insight into container performance. + +2. **Mitigating Heterogeneity**: + + - In this basic use case, we address device heterogeneity by establishing rules tailored to each container's system capabilities. This involves modifying training parameters, such as batch sizes and learning rates, based on each device's memory capacity and CPU availability. These settings are specified in the `client_configs` array in the `create_docker_compose` script. 
For example: + + ```python + client_configs = [ + {"mem_limit": "3g", "batch_size": 32, "cpus": 4, "learning_rate": 0.001}, + {"mem_limit": "6g", "batch_size": 256, "cpus": 1, "learning_rate": 0.05}, + {"mem_limit": "4g", "batch_size": 64, "cpus": 3, "learning_rate": 0.02}, + {"mem_limit": "5g", "batch_size": 128, "cpus": 2.5, "learning_rate": 0.09}, + ] + ``` + +## Prerequisites + +Docker must be installed and the Docker daemon running on your server. If you don't already have Docker installed, you can get [installation instructions for your specific Linux distribution or macOS from Docker](https://docs.docker.com/engine/install/). Besides Docker, the only extra requirement is having Python installed. You don't need to create a new environment for this example since all dependencies will be installed inside Docker containers automatically. + +## Running the Example + +Running this example is easy. For a more detailed step-by-step guide, including more useful material, refer to the detailed guide in the following section. + +```bash + +# Generate docker compose file +python helpers/generate_docker_compose.py # by default will configure to use 2 clients for 100 rounds + +# Build docker images +docker-compose build + +# Launch everything +docker-compose up +``` + +On your favourite browser, go to `http://localhost:3000` to see the Grafana dashboard showing system-level and application-level metrics. + +To stop all containers, open a new terminal and `cd` into this directory, then run `docker-compose down`. Alternatively, you can do `ctrl+c` on the same terminal and then run `docker-compose down` to ensure everything is terminated. + +## Running the Example (detailed) + +### Step 1: Configure Docker Compose + +Execute the following command to run the `helpers/generate_docker_compose.py` script. This script creates the docker-compose configuration needed to set up the environment.
+ +```bash +python helpers/generate_docker_compose.py +``` + +Within the script, specify the number of clients (`total_clients`) and resource limitations for each client in the `client_configs` array. You can adjust the number of rounds by passing `--num_rounds` to the above command. + +### Step 2: Build and Launch Containers + +1. **Execute Initialization Script**: + + - To build the Docker images and start the containers, use the following command: + + ```bash + # this is the only command you need to execute to run the entire example + docker-compose up + ``` + + - If you make any changes to the Dockerfile or other configuration files, you should rebuild the images to reflect these changes. This can be done by adding the `--build` flag to the command: + + ```bash + docker-compose up --build + ``` + + - The `--build` flag instructs Docker Compose to rebuild the images before starting the containers, ensuring that any code or configuration changes are included. + + - To stop all services, you have two options: + + - Run `docker-compose down` in another terminal if you are in the same directory. This command will stop and remove the containers, networks, and volumes created by `docker-compose up`. + - Press `Ctrl+C` once in the terminal where `docker-compose up` is running. This will stop the containers but won't remove them or the networks and volumes they use. + +2. **Services Startup**: + + - Several services will automatically launch as defined in your `docker-compose.yml` file: + + - **Monitoring Services**: Prometheus for metrics collection, Cadvisor for container monitoring, and Grafana for data visualization. + - **Flower Federated Learning Environment**: The Flower server and client containers are initialized and start running. + + - After launching the services, verify that all Docker containers are running correctly by executing the `docker ps` command. 
Here's an example output: + + ```bash + ➜ ~ docker ps + CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES + 9f05820eba45 flower-via-docker-compose-client2 "python client.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6002->6002/tcp client2 + a0333715d504 flower-via-docker-compose-client1 "python client.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6001->6001/tcp client1 + 0da2bf735965 flower-via-docker-compose-server "python server.py --…" 50 seconds ago Up 48 seconds 0.0.0.0:6000->6000/tcp, 0.0.0.0:8000->8000/tcp, 0.0.0.0:8265->8265/tcp server + c57ef50657ae grafana/grafana:latest "/run.sh --config=/e…" 50 seconds ago Up 49 seconds 0.0.0.0:3000->3000/tcp grafana + 4f274c2083dc prom/prometheus:latest "/bin/prometheus --c…" 50 seconds ago Up 49 seconds 0.0.0.0:9090->9090/tcp prometheus + e9f4c9644a1c gcr.io/cadvisor/cadvisor:v0.47.0 "/usr/bin/cadvisor -…" 50 seconds ago Up 49 seconds 0.0.0.0:8080->8080/tcp cadvisor + ``` + + - To monitor the resource utilization of your containers in real-time and see the limits imposed in the Docker Compose file, you can use the `docker stats` command. This command provides a live stream of container CPU, memory, and network usage statistics. + + ```bash + ➜ ~ docker stats + CONTAINER ID NAME CPU % MEM USAGE / LIMIT MEM % NET I/O BLOCK I/O PIDS + 9f05820eba45 client2 104.44% 1.968GiB / 6GiB 32.80% 148MB / 3.22MB 0B / 284MB 82 + a0333715d504 client1 184.69% 1.498GiB / 3GiB 49.92% 149MB / 2.81MB 1.37MB / 284MB 82 + 0da2bf735965 server 0.12% 218.5MiB / 15.61GiB 1.37% 1.47MB / 2.89MB 2.56MB / 2.81MB 45 + c57ef50657ae grafana 0.24% 96.19MiB / 400MiB 24.05% 18.9kB / 3.79kB 77.8kB / 152kB 20 + 4f274c2083dc prometheus 1.14% 52.73MiB / 500MiB 10.55% 6.79MB / 211kB 1.02MB / 1.31MB 15 + e9f4c9644a1c cadvisor 7.31% 32.14MiB / 500MiB 6.43% 139kB / 6.66MB 500kB / 0B 18 + ``` + +3. 
**Automated Grafana Configuration**: + + - Grafana is configured to load pre-defined data sources and dashboards for immediate monitoring, facilitated by provisioning files. The provisioning files include `prometheus-datasource.yml` for data sources, located in the `./config/provisioning/datasources` directory, and `dashboard_index.json` for dashboards, in the `./config/provisioning/dashboards` directory. The `grafana.ini` file is also tailored to enhance user experience: + - **Admin Credentials**: We provide default admin credentials in the `grafana.ini` configuration, which simplifies access by eliminating the need for users to go through the initial login process. + - **Default Dashboard Path**: A default dashboard path is set in `grafana.ini` to ensure that the dashboard with all the necessary panels is rendered when Grafana is accessed. + + These files and settings are directly mounted into the Grafana container via Docker Compose volume mappings. This setup guarantees that upon startup, Grafana is pre-configured for monitoring, requiring no additional manual setup. + +4. **Begin Training Process**: + + - The federated learning training automatically begins once all client containers are successfully connected to the Flower server. This synchronizes the learning process across all participating clients. + +By following these steps, you will have a fully functional federated learning environment with device heterogeneity and monitoring capabilities. + +## Model Training and Dataset Integration + +### Data Pipeline with FLWR-Datasets + +We have integrated [`flwr-datasets`](https://flower.ai/docs/datasets/) into our data pipeline, which is managed within the `load_data.py` file in the `helpers/` directory. This script facilitates standardized access to datasets across the federated network and incorporates a `data_sampling_percentage` argument. 
This argument allows users to specify the percentage of the dataset to be used for training and evaluation, accommodating devices with lower memory capabilities to prevent Out-of-Memory (OOM) errors. + +### Model Selection and Dataset + +For the federated learning system, we have selected the MobileNet model due to its efficiency in image classification tasks. The model is trained and evaluated on the CIFAR-10 dataset. The combination of MobileNet and CIFAR-10 is ideal for demonstrating the capabilities of our federated learning solution in a heterogeneous device environment. + +- **MobileNet**: A streamlined architecture for mobile and embedded devices that balances performance and computational cost. +- **CIFAR-10 Dataset**: A standard benchmark dataset for image classification, containing various object classes that pose a comprehensive challenge for the learning model. + +By integrating these components, our framework is well-prepared to handle the intricacies of training over a distributed network with varying device capabilities and data availability. + +## Visualizing with Grafana + +### Access Grafana Dashboard + +Visit `http://localhost:3000` to enter Grafana. The automated setup ensures that you're greeted with a series of pre-configured dashboards, including the default screen with a comprehensive set of graphs. These dashboards are ready for immediate monitoring and can be customized to suit your specific requirements. + +### Dashboard Configuration + +The `dashboard_index.json` file, located in the `./config/provisioning/dashboards` directory, serves as the backbone of our Grafana dashboard's configuration. It defines the structure and settings of the dashboard panels, which are rendered when you access Grafana. This JSON file contains the specifications for various panels such as model accuracy, CPU usage, memory utilization, and network traffic. 
Each panel's configuration includes the data source, queries, visualization type, and other display settings like thresholds and colors. + +For instance, in our project setup, the `dashboard_index.json` configures a panel to display the model's accuracy over time using a time-series graph, and another panel to show the CPU usage across clients using a graph that plots data points as they are received. This file is fundamental for creating a customized and informative dashboard that provides a snapshot of the federated learning system's health and performance metrics. + +By modifying the `dashboard_index.json` file, users can tailor the Grafana dashboard to include additional metrics or change the appearance and behavior of existing panels to better fit their monitoring requirements. + +### Grafana Default Dashboard + +Below is the default Grafana dashboard that users will see upon accessing Grafana: + +grafana_home_screen + +This comprehensive dashboard provides insights into various system metrics across client-server containers. It includes visualizations such as: + +- **Application Metrics**: The "Model Accuracy" graph shows an upward trend as rounds of training progress, which is a positive indicator of the model learning and improving over time. Conversely, the "Model Loss" graph trends downward, suggesting that the model is becoming more precise and making fewer mistakes as it trains. + +- **CPU Usage**: The sharp spikes in the red graph, representing "client1", indicate peak CPU usage, which is considerably higher than that of "client2" (blue graph). This difference is due to "client1" being allocated more computing resources (up to 4 CPU cores) compared to "client2", which is limited to just 1 CPU core, hence the more subdued CPU usage pattern. + +- **Memory Utilization**: Both clients are allocated a similar amount of memory, reflected in the nearly same lines for memory usage. 
This uniform allocation allows for a straightforward comparison of how each client manages memory under similar conditions. + +- **Network Traffic**: Monitor incoming and outgoing network traffic to each client, which is crucial for understanding data exchange volumes during federated learning cycles. + +Together, these metrics paint a detailed picture of the federated learning operation, showcasing resource usage and model performance. Such insights are invaluable for system optimization, ensuring balanced load distribution and efficient model training. + +## Comprehensive Monitoring System Integration + +### Capturing Container Metrics with cAdvisor + +cAdvisor is seamlessly integrated into our monitoring setup to capture a variety of system and container metrics, such as CPU, memory, and network usage. These metrics are vital for analyzing the performance and resource consumption of the containers in the federated learning environment. + +### Custom Metrics: Setup and Monitoring via Prometheus + +In addition to the standard metrics captured by cAdvisor, we have implemented a process to track custom metrics like model's accuracy and loss within Grafana, using Prometheus as the backbone for metric collection. + +1. **Prometheus Client Installation**: + + - We began by installing the `prometheus_client` library in our Python environment, enabling us to define and expose custom metrics that Prometheus can scrape. + +2. **Defining Metrics in Server Script**: + + - Within our `server.py` script, we have established two key Prometheus Gauge metrics, specifically tailored for monitoring our federated learning model: `model_accuracy` and `model_loss`. These custom gauges are instrumental in capturing the most recent values of the model's accuracy and loss, which are essential metrics for evaluating the model's performance. 
The gauges are defined as follows: + + ```python + from prometheus_client import Gauge + + accuracy_gauge = Gauge('model_accuracy', 'Current accuracy of the global model') + loss_gauge = Gauge('model_loss', 'Current loss of the global model') + ``` + +3. **Exposing Metrics via HTTP Endpoint**: + + - We leveraged the `start_http_server` function from the `prometheus_client` library to launch an HTTP server on port 8000. This server provides the `/metrics` endpoint, where the custom metrics are accessible for Prometheus scraping. The function is called at the end of the `main` method in `server.py`: + + ```python + start_http_server(8000) + ``` + +4. **Updating Metrics Recording Strategy**: + + - The core of our metrics tracking lies in the `strategy.py` file, particularly within the `aggregate_evaluate` method. This method is crucial as it's where the federated learning model's accuracy and loss values are computed after each round of training with the aggregated data from all clients. + + ```python + self.accuracy_gauge.set(accuracy_aggregated) + self.loss_gauge.set(loss_aggregated) + ``` + +5. **Configuring Prometheus Scraping**: + + - In the `prometheus.yml` file, under `scrape_configs`, we configured a new job to scrape the custom metrics from the HTTP server. This setup includes the job's name, the scraping interval, and the target server's URL. + +### Visualizing the Monitoring Architecture + +The image below depicts the Prometheus scraping process as it is configured in our monitoring setup. Within this architecture: + +- The "Prometheus server" is the central component that retrieves and stores metrics. +- "cAdvisor" and the "HTTP server" we set up to expose our custom metrics are represented as "Prometheus targets" in the diagram. cAdvisor captures container metrics, while the HTTP server serves our custom `model_accuracy` and `model_loss` metrics at the `/metrics` endpoint. 
+- These targets are periodically scraped by the Prometheus server, aggregating data from both system-level and custom performance metrics. +- The aggregated data is then made available to the "Prometheus web UI" and "Grafana," as shown, enabling detailed visualization and analysis through the Grafana dashboard. + +prometheus-architecture + +By incorporating these steps, we have enriched our monitoring capabilities to not only include system-level metrics but also critical performance indicators of our federated learning model. This approach is pivotal for understanding and improving the learning process. Similarly, you can apply this methodology to track any other metric that you find interesting or relevant to your specific needs. This flexibility allows for a comprehensive and customized monitoring environment, tailored to the unique aspects and requirements of your federated learning system. + +## Additional Resources + +- **Grafana Tutorials**: Explore a variety of tutorials on Grafana at [Grafana Tutorials](https://grafana.com/tutorials/). +- **Prometheus Overview**: Learn more about Prometheus at their [official documentation](https://prometheus.io/docs/introduction/overview/). +- **cAdvisor Guide**: For information on monitoring Docker containers with cAdvisor, see this [Prometheus guide](https://prometheus.io/docs/guides/cadvisor/). + +## Conclusion + +This project serves as a foundational example of managing device heterogeneity within the federated learning context, employing the Flower framework alongside Docker, Prometheus, and Grafana. It's designed to be a starting point for users to explore and further adapt to the complexities of device heterogeneity in federated learning environments. 
diff --git a/examples/flower-via-docker-compose/client.py b/examples/flower-via-docker-compose/client.py new file mode 100644 index 000000000000..c894143532a1 --- /dev/null +++ b/examples/flower-via-docker-compose/client.py @@ -0,0 +1,110 @@ +import os +import argparse +import flwr as fl +import tensorflow as tf +import logging +from helpers.load_data import load_data +import os +from model.model import Model + +logging.basicConfig(level=logging.INFO) # Configure logging +logger = logging.getLogger(__name__) # Create logger for the module + +# Make TensorFlow log less verbose +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Flower client") + +parser.add_argument( + "--server_address", type=str, default="server:8080", help="Address of the server" +) +parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size for training" +) +parser.add_argument( + "--learning_rate", type=float, default=0.1, help="Learning rate for the optimizer" +) +parser.add_argument("--client_id", type=int, default=1, help="Unique ID for the client") +parser.add_argument( + "--total_clients", type=int, default=2, help="Total number of clients" +) +parser.add_argument( + "--data_percentage", type=float, default=0.5, help="Portion of client data to use" +) + +args = parser.parse_args() + +# Create an instance of the model and pass the learning rate as an argument +model = Model(learning_rate=args.learning_rate) + +# Compile the model +model.compile() + + +class Client(fl.client.NumPyClient): + def __init__(self, args): + self.args = args + + logger.info("Preparing data...") + (x_train, y_train), (x_test, y_test) = load_data( + data_sampling_percentage=self.args.data_percentage, + client_id=self.args.client_id, + total_clients=self.args.total_clients, + ) + + self.x_train = x_train + self.y_train = y_train + self.x_test = x_test + self.y_test = y_test + + def get_parameters(self, config): + # Return the 
parameters of the model + return model.get_model().get_weights() + + def fit(self, parameters, config): + # Set the weights of the model + model.get_model().set_weights(parameters) + + # Train the model + history = model.get_model().fit( + self.x_train, self.y_train, batch_size=self.args.batch_size + ) + + # Calculate evaluation metric + results = { + "accuracy": float(history.history["accuracy"][-1]), + } + + # Get the parameters after training + parameters_prime = model.get_model().get_weights() + + # Directly return the parameters and the number of examples trained on + return parameters_prime, len(self.x_train), results + + def evaluate(self, parameters, config): + # Set the weights of the model + model.get_model().set_weights(parameters) + + # Evaluate the model and get the loss and accuracy + loss, accuracy = model.get_model().evaluate( + self.x_test, self.y_test, batch_size=self.args.batch_size + ) + + # Return the loss, the number of examples evaluated on and the accuracy + return float(loss), len(self.x_test), {"accuracy": float(accuracy)} + + +# Function to Start the Client +def start_fl_client(): + try: + client = Client(args).to_client() + fl.client.start_client(server_address=args.server_address, client=client) + except Exception as e: + logger.error("Error starting FL client: %s", e) + return {"status": "error", "message": str(e)} + + +if __name__ == "__main__": + # Call the function to start the client + start_fl_client() diff --git a/examples/flower-via-docker-compose/config/grafana.ini b/examples/flower-via-docker-compose/config/grafana.ini new file mode 100644 index 000000000000..775f39d7ec22 --- /dev/null +++ b/examples/flower-via-docker-compose/config/grafana.ini @@ -0,0 +1,12 @@ +[security] +allow_embedding = true +admin_user = admin +admin_password = admin + +[dashboards] +default_home_dashboard_path = /etc/grafana/provisioning/dashboards/dashboard_index.json + +[auth.anonymous] +enabled = true +org_name = Main Org. 
+org_role = Admin diff --git a/examples/flower-via-docker-compose/config/prometheus.yml b/examples/flower-via-docker-compose/config/prometheus.yml new file mode 100644 index 000000000000..46cf07b9dcee --- /dev/null +++ b/examples/flower-via-docker-compose/config/prometheus.yml @@ -0,0 +1,19 @@ + +global: + scrape_interval: 1s + evaluation_interval: 1s + +rule_files: +scrape_configs: + - job_name: 'cadvisor' + scrape_interval: 1s + metrics_path: '/metrics' + static_configs: + - targets: ['cadvisor:8080'] + labels: + group: 'cadvisor' + - job_name: 'server_metrics' + scrape_interval: 1s + metrics_path: '/metrics' + static_configs: + - targets: ['server:8000'] \ No newline at end of file diff --git a/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json new file mode 100644 index 000000000000..b52f19c57508 --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboard_index.json @@ -0,0 +1,1051 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "datasource", + "uid": "grafana" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Simple exporter for cadvisor only", + "editable": true, + "fiscalYearStartMonth": 0, + "gnetId": 14282, + "graphTooltip": 0, + "id": 12, + "links": [], + "liveNow": false, + "panels": [ + { + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 22, + "title": "Application metrics", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "description": "Averaged federated accuracy across clients", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", 
+ "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineStyle": { + "fill": "solid" + }, + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 1 + }, + "id": 23, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "model_accuracy", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Model Accuracy", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "description": "Averaged Federated Loss across clients", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + 
"lineInterpolation": "smooth", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 1 + }, + "id": 21, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "model_loss", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Model Loss", + "type": "timeseries" + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 9 + }, + "id": 8, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "CPU", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 10 + }, + "hiddenSeries": false, + "id": 15, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + 
"rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_cpu_usage_seconds_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name) *100", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "CPU Usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "percent", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 17 + }, + "id": 11, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Memory", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 
18 + }, + "hiddenSeries": false, + "id": 9, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(container_memory_rss{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Memory Usage", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "bytes", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 18 + }, + "hiddenSeries": false, + "id": 14, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + 
"nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(container_memory_cache{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Memory Cached", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:606", + "format": "bytes", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:607", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 26 + }, + "id": 2, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Network", + "type": "row" + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 27 + }, + "hiddenSeries": false, + "id": 4, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "hideEmpty": false, 
+ "hideZero": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_network_receive_bytes_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name)", + "hide": false, + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Received Network Traffic", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:674", + "format": "Bps", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:675", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "aliasColors": { + "client1": "red", + "client2": "blue", + "server": "yellow" + }, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 27 + }, + "hiddenSeries": false, + "id": 6, + "legend": { + "alignAsTable": true, + "avg": true, + "current": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + 
"percentage": false, + "pluginVersion": "10.2.2", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "editorMode": "code", + "expr": "sum(rate(container_network_transmit_bytes_total{instance=~\"$host\",name=~\"$container\",name=~\".+\", name !~ \"(prometheus|cadvisor|grafana)\"}[10s])) by (name)", + "interval": "", + "legendFormat": "{{name}}", + "range": true, + "refId": "A" + } + ], + "thresholds": [], + "timeRegions": [], + "title": "Sent Network Traffic", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "mode": "time", + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:832", + "format": "Bps", + "logBase": 1, + "show": true + }, + { + "$$hashKey": "object:833", + "format": "short", + "logBase": 1, + "show": true + } + ], + "yaxis": { + "align": false + } + }, + { + "collapsed": false, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 35 + }, + "id": 19, + "panels": [], + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "refId": "A" + } + ], + "title": "Misc", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "fieldConfig": { + "defaults": { + "custom": { + "align": "auto", + "cellOptions": { + "type": "auto" + }, + "filterable": false, + "inspect": false + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "id" + }, + "properties": [ + { + "id": "custom.width", + 
"value": 260 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Running" + }, + "properties": [ + { + "id": "unit", + "value": "d" + }, + { + "id": "decimals", + "value": 1 + }, + { + "id": "custom.cellOptions", + "value": { + "type": "color-text" + } + }, + { + "id": "color", + "value": { + "fixedColor": "dark-green", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 10, + "w": 24, + "x": 0, + "y": 36 + }, + "id": 17, + "options": { + "cellHeight": "sm", + "footer": { + "countRows": false, + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, + "showHeader": true, + "sortBy": [] + }, + "pluginVersion": "10.2.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "expr": "(time() - container_start_time_seconds{instance=~\"$host\",name=~\"$container\",name=~\".+\"})/86400", + "format": "table", + "instant": true, + "interval": "", + "legendFormat": "{{name}}", + "refId": "A" + } + ], + "title": "Containers Info", + "transformations": [ + { + "id": "filterFieldsByName", + "options": { + "include": { + "names": [ + "container_label_com_docker_compose_project", + "container_label_com_docker_compose_project_working_dir", + "image", + "instance", + "name", + "Value", + "container_label_com_docker_compose_service" + ] + } + } + }, + { + "id": "organize", + "options": { + "excludeByName": {}, + "indexByName": {}, + "renameByName": { + "Value": "Running", + "container_label_com_docker_compose_project": "Label", + "container_label_com_docker_compose_project_working_dir": "Working dir", + "container_label_com_docker_compose_service": "Service", + "image": "Registry Image", + "instance": "Instance", + "name": "Name" + } + } + } + ], + "type": "table" + } + ], + "refresh": "auto", + "schemaVersion": 38, + "tags": [], + "templating": { + "list": [ + { + "allValue": ".*", + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": 
"prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "definition": "label_values({__name__=~\"container.*\"},instance)", + "hide": 0, + "includeAll": true, + "label": "Host", + "multi": false, + "name": "host", + "options": [], + "query": { + "query": "label_values({__name__=~\"container.*\"},instance)", + "refId": "Prometheus-host-Variable-Query" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 5, + "tagValuesQuery": "", + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".*", + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "db69454e-e558-479e-b4fc-80db52bf91da" + }, + "definition": "label_values({__name__=~\"container.*\", instance=~\"$host\"},name)", + "hide": 0, + "includeAll": true, + "label": "Container", + "multi": false, + "name": "container", + "options": [], + "query": { + "query": "label_values({__name__=~\"container.*\", instance=~\"$host\"},name)", + "refId": "Prometheus-container-Variable-Query" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Cadvisor exporter Copy", + "uid": "fcf2a8da-792c-4b9f-a22f-876820b53c2f", + "version": 2, + "weekStart": "" +} \ No newline at end of file diff --git a/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml new file mode 100644 index 000000000000..e0d542f58f2b --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/dashboards/dashboards.yml @@ -0,0 +1,12 @@ +apiVersion: 1 + +providers: +- name: 'default' + orgId: 1 + folder: '' + type: file + disableDeletion: false + editable: true + updateIntervalSeconds: 5 # Optional: How 
often Grafana will scan for changed dashboards + options: + path: /etc/grafana/provisioning/dashboards diff --git a/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml b/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml new file mode 100644 index 000000000000..7c8ce00fdcdc --- /dev/null +++ b/examples/flower-via-docker-compose/config/provisioning/datasources/prometheus-datasource.yml @@ -0,0 +1,9 @@ +apiVersion: 1 + +datasources: +- name: Prometheus + type: prometheus + access: proxy + uid: db69454e-e558-479e-b4fc-80db52bf91da + url: http://host.docker.internal:9090 + isDefault: true diff --git a/examples/flower-via-docker-compose/helpers/generate_docker_compose.py b/examples/flower-via-docker-compose/helpers/generate_docker_compose.py new file mode 100644 index 000000000000..cde553a95e68 --- /dev/null +++ b/examples/flower-via-docker-compose/helpers/generate_docker_compose.py @@ -0,0 +1,147 @@ +import random +import argparse + +parser = argparse.ArgumentParser(description="Generated Docker Compose") +parser.add_argument( + "--total_clients", type=int, default=2, help="Total clients to spawn (default: 2)" +) +parser.add_argument( + "--num_rounds", type=int, default=100, help="Number of FL rounds (default: 100)" +) +parser.add_argument( + "--data_percentage", + type=float, + default=0.6, + help="Portion of client data to use (default: 0.6)", +) +parser.add_argument( + "--random", action="store_true", help="Randomize client configurations" +) + + +def create_docker_compose(args): + # cpus is used to set the number of CPUs available to the container as a fraction of the total number of CPUs on the host machine. + # mem_limit is used to set the memory limit for the container. 
+ client_configs = [ + {"mem_limit": "3g", "batch_size": 32, "cpus": 4, "learning_rate": 0.001}, + {"mem_limit": "6g", "batch_size": 256, "cpus": 1, "learning_rate": 0.05}, + {"mem_limit": "4g", "batch_size": 64, "cpus": 3, "learning_rate": 0.02}, + {"mem_limit": "5g", "batch_size": 128, "cpus": 2.5, "learning_rate": 0.09}, + # Add or modify the configurations depending on your host machine + ] + + docker_compose_content = f""" +version: '3' +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + ports: + - 9090:9090 + deploy: + restart_policy: + condition: on-failure + command: + - --config.file=/etc/prometheus/prometheus.yml + volumes: + - ./config/prometheus.yml:/etc/prometheus/prometheus.yml:ro + depends_on: + - cadvisor + + cadvisor: + image: gcr.io/cadvisor/cadvisor:v0.47.0 + container_name: cadvisor + privileged: true + deploy: + restart_policy: + condition: on-failure + ports: + - "8080:8080" + volumes: + - /:/rootfs:ro + - /var/run:/var/run:ro + - /sys:/sys:ro + - /var/lib/docker/:/var/lib/docker:ro + - /dev/disk/:/dev/disk:ro + - /var/run/docker.sock:/var/run/docker.sock + + grafana: + image: grafana/grafana:latest + container_name: grafana + ports: + - 3000:3000 + deploy: + restart_policy: + condition: on-failure + volumes: + - grafana-storage:/var/lib/grafana + - ./config/grafana.ini:/etc/grafana/grafana.ini + - ./config/provisioning/datasources:/etc/grafana/provisioning/datasources + - ./config/provisioning/dashboards:/etc/grafana/provisioning/dashboards + depends_on: + - prometheus + - cadvisor + command: + - --config=/etc/grafana/grafana.ini + + + server: + container_name: server + build: + context: . 
+ dockerfile: Dockerfile + command: python server.py --number_of_rounds={args.num_rounds} + environment: + FLASK_RUN_PORT: 6000 + DOCKER_HOST_IP: host.docker.internal + volumes: + - .:/app + - /var/run/docker.sock:/var/run/docker.sock + ports: + - "6000:6000" + - "8265:8265" + - "8000:8000" + depends_on: + - prometheus + - grafana +""" + # Add client services + for i in range(1, args.total_clients + 1): + if args.random: + config = random.choice(client_configs) + else: + config = client_configs[(i - 1) % len(client_configs)] + docker_compose_content += f""" + client{i}: + container_name: client{i} + build: + context: . + dockerfile: Dockerfile + command: python client.py --server_address=server:8080 --data_percentage={args.data_percentage} --client_id={i} --total_clients={args.total_clients} --batch_size={config["batch_size"]} --learning_rate={config["learning_rate"]} + deploy: + resources: + limits: + cpus: "{(config['cpus'])}" + memory: "{config['mem_limit']}" + volumes: + - .:/app + - /var/run/docker.sock:/var/run/docker.sock + ports: + - "{6000 + i}:{6000 + i}" + depends_on: + - server + environment: + FLASK_RUN_PORT: {6000 + i} + container_name: client{i} + DOCKER_HOST_IP: host.docker.internal +""" + + docker_compose_content += "volumes:\n grafana-storage:\n" + + with open("docker-compose.yml", "w") as file: + file.write(docker_compose_content) + + +if __name__ == "__main__": + args = parser.parse_args() + create_docker_compose(args) diff --git a/examples/flower-via-docker-compose/helpers/load_data.py b/examples/flower-via-docker-compose/helpers/load_data.py new file mode 100644 index 000000000000..1f2784946868 --- /dev/null +++ b/examples/flower-via-docker-compose/helpers/load_data.py @@ -0,0 +1,37 @@ +import numpy as np +import tensorflow as tf +from flwr_datasets import FederatedDataset +import logging + +logging.basicConfig(level=logging.INFO) # Configure logging +logger = logging.getLogger(__name__) # Create logger for the module + + +def 
load_data(data_sampling_percentage=0.5, client_id=1, total_clients=2): + """Load federated dataset partition based on client ID. + + Args: + data_sampling_percentage (float): Percentage of the dataset to use for training. + client_id (int): Unique ID for the client. + total_clients (int): Total number of clients. + + Returns: + Tuple of arrays: `(x_train, y_train), (x_test, y_test)`. + """ + + # Download and partition dataset + fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_clients}) + partition = fds.load_partition(client_id - 1, "train") + partition.set_format("numpy") + + # Divide data on each client: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + + # Apply data sampling + num_samples = int(data_sampling_percentage * len(x_train)) + indices = np.random.choice(len(x_train), num_samples, replace=False) + x_train, y_train = x_train[indices], y_train[indices] + + return (x_train, y_train), (x_test, y_test) diff --git a/examples/flower-via-docker-compose/model/model.py b/examples/flower-via-docker-compose/model/model.py new file mode 100644 index 000000000000..ab26d089b858 --- /dev/null +++ b/examples/flower-via-docker-compose/model/model.py @@ -0,0 +1,18 @@ +import tensorflow as tf + + +# Class for the model. 
In this case, we are using the MobileNetV2 model from Keras +class Model: + def __init__(self, learning_rate): + self.learning_rate = learning_rate + self.loss_function = tf.keras.losses.SparseCategoricalCrossentropy() + self.model = tf.keras.applications.MobileNetV2( + (32, 32, 3), alpha=0.1, classes=10, weights=None + ) + self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + def compile(self): + self.model.compile(self.optimizer, self.loss_function, metrics=["accuracy"]) + + def get_model(self): + return self.model diff --git a/examples/flower-via-docker-compose/requirements.txt b/examples/flower-via-docker-compose/requirements.txt new file mode 100644 index 000000000000..b93e5b1d9f2b --- /dev/null +++ b/examples/flower-via-docker-compose/requirements.txt @@ -0,0 +1,5 @@ +flwr==1.7.0 +tensorflow==2.13.1 +numpy==1.24.3 +prometheus_client == 0.19.0 +flwr_datasets[vision] == 0.0.2 diff --git a/examples/flower-via-docker-compose/server.py b/examples/flower-via-docker-compose/server.py new file mode 100644 index 000000000000..99d1a7ef7399 --- /dev/null +++ b/examples/flower-via-docker-compose/server.py @@ -0,0 +1,47 @@ +import argparse +import flwr as fl +import logging +from strategy.strategy import FedCustom +from prometheus_client import start_http_server, Gauge + +# Initialize Logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define a gauge to track the global model accuracy +accuracy_gauge = Gauge("model_accuracy", "Current accuracy of the global model") + +# Define a gauge to track the global model loss +loss_gauge = Gauge("model_loss", "Current loss of the global model") + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Flower Server") +parser.add_argument( + "--number_of_rounds", + type=int, + default=100, + help="Number of FL rounds (default: 100)", +) +args = parser.parse_args() + + +# Function to Start Federated Learning Server +def start_fl_server(strategy, 
rounds): + try: + fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=rounds), + strategy=strategy, + ) + except Exception as e: + logger.error(f"FL Server error: {e}", exc_info=True) + + +# Main Function +if __name__ == "__main__": + # Start Prometheus Metrics Server + start_http_server(8000) + + # Initialize Strategy Instance and Start FL Server + strategy_instance = FedCustom(accuracy_gauge=accuracy_gauge, loss_gauge=loss_gauge) + start_fl_server(strategy=strategy_instance, rounds=args.number_of_rounds) diff --git a/examples/flower-via-docker-compose/strategy/strategy.py b/examples/flower-via-docker-compose/strategy/strategy.py new file mode 100644 index 000000000000..9471a99f037f --- /dev/null +++ b/examples/flower-via-docker-compose/strategy/strategy.py @@ -0,0 +1,60 @@ +from typing import Dict, List, Optional, Tuple, Union +from flwr.common import Scalar, EvaluateRes +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg +import flwr as fl +import logging +from prometheus_client import Gauge + +logging.basicConfig(level=logging.INFO) # Configure logging +logger = logging.getLogger(__name__) # Create logger for the module + + +class FedCustom(fl.server.strategy.FedAvg): + def __init__( + self, accuracy_gauge: Gauge = None, loss_gauge: Gauge = None, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.accuracy_gauge = accuracy_gauge + self.loss_gauge = loss_gauge + + def __repr__(self) -> str: + return "FedCustom" + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses and accuracy using weighted average.""" + + if not results: + return None, {} + + # Calculate weighted average for loss using the provided function + loss_aggregated 
= weighted_loss_avg( + [ + (evaluate_res.num_examples, evaluate_res.loss) + for _, evaluate_res in results + ] + ) + + # Calculate weighted average for accuracy + accuracies = [ + evaluate_res.metrics["accuracy"] * evaluate_res.num_examples + for _, evaluate_res in results + ] + examples = [evaluate_res.num_examples for _, evaluate_res in results] + accuracy_aggregated = ( + sum(accuracies) / sum(examples) if sum(examples) != 0 else 0 + ) + + # Update the Prometheus gauges with the latest aggregated accuracy and loss values + self.accuracy_gauge.set(accuracy_aggregated) + self.loss_gauge.set(loss_aggregated) + + metrics_aggregated = {"loss": loss_aggregated, "accuracy": accuracy_aggregated} + + return loss_aggregated, metrics_aggregated diff --git a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift index 6075e689cb1e..3348ceae9e0e 100644 --- a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift +++ b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.pb.swift @@ -4308,7 +4308,7 @@ struct CoreML_Specification_ConvolutionLayerParams { /// The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. /// /// Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. -/// Output spatial dimensions depend on the the type of padding. For details, refer to the +/// Output spatial dimensions depend on the type of padding. For details, refer to the /// descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. /// /// Example @@ -5219,7 +5219,7 @@ struct CoreML_Specification_Pooling3DLayerParams { /// SAME padding adds enough padding to each dimension such that the output /// has the same spatial dimensions as the input. Padding is added evenly to both /// sides of each dimension unless the total padding to add is odd, in which case the extra padding - /// is added to the back/bottom/right side of the respective dimension. 
For example, if the the + /// is added to the back/bottom/right side of the respective dimension. For example, if the /// total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. enum Pooling3DPaddingType: SwiftProtobuf.Enum { typealias RawValue = Int @@ -9493,7 +9493,7 @@ struct CoreML_Specification_ConstantPaddingLayerParams { ///* /// Length of this repeated field must be twice the rank of the first input. - /// 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the the i-th input + /// 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the i-th input /// dimension, "before" and "after" the input values, respectively. var padAmounts: [UInt64] = [] diff --git a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto index 6b2ebb1c8ba1..e80a3566c2c3 100644 --- a/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto +++ b/examples/ios/FLiOS/CoreMLProto/NeuralNetwork.proto @@ -1554,7 +1554,7 @@ message ConvolutionLayerParams { * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. * * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. - * Output spatial dimensions depend on the the type of padding. For details, refer to the + * Output spatial dimensions depend on the type of padding. For details, refer to the * descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. * * Example @@ -2056,7 +2056,7 @@ message Pooling3DLayerParams { * SAME padding adds enough padding to each dimension such that the output * has the same spatial dimensions as the input. Padding is added evenly to both * sides of each dimension unless the total padding to add is odd, in which case the extra padding - * is added to the back/bottom/right side of the respective dimension. 
For example, if the the + * is added to the back/bottom/right side of the respective dimension. For example, if the * total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. */ enum Pooling3DPaddingType { @@ -4779,7 +4779,7 @@ enum ScatterMode { * * The output has the same shape as the first input. * "indices" input must have rank less than or equal to the "updates" input and its shape - * must be a subset of the the shape of the "updates" input. + * must be a subset of the shape of the "updates" input. * * e.g: * @@ -5064,7 +5064,7 @@ message ConstantPaddingLayerParams { /** * Length of this repeated field must be twice the rank of the first input. - * 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the the i-th input + * 2*i-th and (2*i+1)-th values represent the amount of padding to be applied to the i-th input * dimension, "before" and "after" the input values, respectively. */ repeated uint64 padAmounts = 2; diff --git a/examples/ios/pyproject.toml b/examples/ios/pyproject.toml index c1bdbb815bd5..2e55b14cf761 100644 --- a/examples/ios/pyproject.toml +++ b/examples/ios/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_ios_coreml" version = "0.1.0" description = "Example Server for Flower iOS/CoreML" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/llm-flowertune/README.md b/examples/llm-flowertune/README.md new file mode 100644 index 000000000000..60e183d2a9c0 --- /dev/null +++ b/examples/llm-flowertune/README.md @@ -0,0 +1,139 @@ +# Federated Large Language Model (LLM) Fine-tuning with Flower + +Large language models (LLMs), which have been trained on vast amounts of publicly accessible data, have shown remarkable effectiveness in a wide range of areas. 
+However, despite the fact that more data typically leads to improved performance, there is a concerning prospect that the supply of high-quality public data will deplete within a few years. +Federated LLM training could unlock access to an endless pool of distributed private data by allowing multiple data owners to collaboratively train a shared model without the need to exchange raw data. + +This introductory example conducts federated instruction tuning with pretrained [LLama2](https://huggingface.co/openlm-research) models on [Alpaca-GPT4](https://huggingface.co/datasets/vicgalle/alpaca-gpt4) dataset. +We use [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset. +The fine-tuning is done using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library. +We use Flower's Simulation Engine to simulate the LLM fine-tuning process in a federated way, +which allows users to perform the training on a single GPU. + +## Environments Setup + +Start by cloning the code example. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/llm-flowertune . && rm -rf flower && cd llm-flowertune +``` + +This will create a new directory called `llm-flowertune` containing the following files: + +``` +-- README.md <- You're reading this right now +-- main.py <- Start fed-LLM simulation +-- client.py <- Flower client constructor +-- model.py <- Model build +-- dataset.py <- Dataset and tokenizer build +-- utils.py <- Utility functions +-- test.py <- Test pre-trained model +-- app.py <- ServerApp/ClientApp for Flower-Next +-- conf/config.yaml <- Configuration file +-- requirements.txt <- Example dependencies +``` + +### Installing dependencies + +Project dependencies are defined in `requirements.txt`.
Install them with: + +```shell +pip install -r requirements.txt +``` + +## Run LLM Fine-tuning + +With an activated Python environment, run the example with default config values. The config is in `conf/config.yaml` and is loaded automatically. + +```bash +# Run with default config +python main.py +``` + +This command will run FL simulations with a 4-bit [OpenLLaMA 7Bv2](https://huggingface.co/openlm-research/open_llama_7b_v2) model involving 2 clients per round for 100 FL rounds. You can override configuration parameters directly from the command line. Below are a few settings you might want to test: + +```bash +# Use OpenLLaMA-3B instead of 7B and 8-bit quantization +python main.py model.name="openlm-research/open_llama_3b_v2" model.quantization=8 + +# Run for 50 rounds while increasing the fraction of clients that participate per round to 25% +python main.py num_rounds=50 strategy.fraction_fit=0.25 +``` + +## Expected Results + +![](_static/train_loss_smooth.png) + +As expected, LLama2-7B model works better than its 3B version with lower training loss. With the hyperparameters tested, the 8-bit model seems to deliver lower training loss for the smaller 3B model compared to its 4-bit version. + +You can run all 8 experiments with a single command as: + +```bash +python main.py --multirun model.name="openlm-research/open_llama_7b_v2","openlm-research/open_llama_3b_v2" model.quantization=8,4 strategy.fraction_fit=0.1,0.2 +``` + +## VRAM Consumption + +| Models | 7-billion (8-bit) | 7-billion (4-bit) | 3-billion (8-bit) | 3-billion (4-bit) | +| :----: | :---------------: | :---------------: | :---------------: | :---------------: | +| VRAM | ~22.00 GB | ~16.50 GB | ~13.50 GB | ~10.60 GB | + +We make use of the [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) library in conjunction with [PEFT](https://huggingface.co/docs/peft/en/index) to derive LLMs that can be fine-tuned efficiently.
+The above table shows the VRAM consumption per client for the different models considered in this example. +You can adjust the CPU/GPU resources you assign to each of the clients based on your device. +For example, it is easy to train 2 concurrent clients on each GPU (24 GB VRAM) if you choose the 3-billion (4-bit) model. + +```bash +# This will assign 50% of the GPU's VRAM to each client. +python main.py model.name="openlm-research/open_llama_3b_v2" model.quantization=4 client_resources.num_gpus=0.5 +``` + +## Test with your Questions + +We provide a script to test your trained model by passing your specified questions. For example: + +```bash +python test.py --peft-path=/path/to/trained-model-dir/ \ + --question="What is the ideal 1-day plan in London?" +``` + +An answer generated from federated trained 7-billion (8-bit) LLama2 model: + +``` +Great choice. +London has so much to offer, and you can really soak up all the sights and sounds in just a single day. +Here's a suggested itinerary for you. +Start your day off with a hearty breakfast at an authentic British diner. +Then head to the iconic Big Ben and the Houses of Parliament to learn about the history of the city. +Next, make your way to Westminster Abbey to see the many historical monuments and memorials. +From there, cross the river Thames to the Tower of London, which is home to the Crown Jewels of England and Scotland. +Finally, end your day with a relaxing visit to the London Eye, the tallest Ferris wheel in Europe, for a beautiful view of the city. +``` + +The [`Vicuna`](https://huggingface.co/lmsys/vicuna-13b-v1.1) template we used in this example is for a chat assistant. +The generated answer is expected to be a multi-turn conversation. Feel free to try more interesting questions! + +## Run with Flower Next (preview) + +We conduct a 2-client setting to demonstrate how to run federated LLM fine-tuning with Flower Next. +Please follow the steps below: + +1.
Start the long-running Flower server (SuperLink)
+   ```bash
+   flower-superlink --insecure
+   ```
+2. Start the long-running Flower client (SuperNode)
+   ```bash
+   # In a new terminal window, start the first long-running Flower client:
+   flower-client-app app:client1 --insecure
+   ```
+   ```bash
+   # In another new terminal window, start the second long-running Flower client:
+   flower-client-app app:client2 --insecure
+   ```
+3. Run the Flower App
+   ```bash
+   # With both the long-running server (SuperLink) and two clients (SuperNode) up and running,
+   # we can now run the actual Flower App:
+   flower-server-app app:server --insecure
+   ```
diff --git a/examples/llm-flowertune/_static/train_loss_smooth.png b/examples/llm-flowertune/_static/train_loss_smooth.png
new file mode 100644
index 000000000000..02034e6e9eb5
Binary files /dev/null and b/examples/llm-flowertune/_static/train_loss_smooth.png differ
diff --git a/examples/llm-flowertune/app.py b/examples/llm-flowertune/app.py
new file mode 100644
index 000000000000..e04ad8715de6
--- /dev/null
+++ b/examples/llm-flowertune/app.py
@@ -0,0 +1,86 @@
+import os
+import warnings
+from hydra import compose, initialize
+
+import flwr as fl
+from flwr_datasets import FederatedDataset
+
+from dataset import get_tokenizer_and_data_collator_and_propt_formatting
+from client import gen_client_fn
+from utils import get_on_fit_config, fit_weighted_average
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+NUM_ROUNDS = 100
+save_path = "./results/"
+
+with initialize(config_path="conf"):
+    cfg = compose(config_name="config")
+
+# Reset the number of rounds
+cfg.num_rounds = NUM_ROUNDS
+cfg.train.num_rounds = NUM_ROUNDS
+
+# Create output directory
+if not os.path.exists(save_path):
+    os.mkdir(save_path)
+
+# Partition dataset and get dataloaders
+# We set the number of partitions to 20 for fast processing.
+fds = FederatedDataset(
+    dataset=cfg.dataset.name, partitioners={"train": cfg.num_clients}
+)
+(
+    tokenizer,
+    data_collator,
+    formatting_prompts_func,
+) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+
+# ClientApp for client #1 (Flower Next)
+client1 = fl.client.ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+        partition_id=0,
+        api=True,
+    ),
+)
+
+
+# ClientApp for client #2 (Flower Next)
+client2 = fl.client.ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+        partition_id=1,
+        api=True,
+    ),
+)
+
+
+# Instantiate strategy.
+strategy = fl.server.strategy.FedAvg(
+    min_available_clients=2,  # Simulate a 2-client setting
+    fraction_fit=1.0,
+    fraction_evaluate=0.0,  # no client evaluation
+    on_fit_config_fn=get_on_fit_config(),
+    fit_metrics_aggregation_fn=fit_weighted_average,
+)
+
+# ServerApp for Flower-Next
+server = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
+    strategy=strategy,
+)
diff --git a/examples/llm-flowertune/client.py b/examples/llm-flowertune/client.py
new file mode 100644
index 000000000000..28b324ba5bf1
--- /dev/null
+++ b/examples/llm-flowertune/client.py
@@ -0,0 +1,129 @@
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import flwr as fl
+import torch
+from flwr.common.typing import NDArrays, Scalar
+from omegaconf import DictConfig
+from trl import SFTTrainer
+from transformers import TrainingArguments
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+
+from models import get_model, cosine_annealing
+
+
+# pylint: disable=too-many-arguments
+class FlowerClient(
+    fl.client.NumPyClient
+):  # pylint: disable=too-many-instance-attributes
+    """Standard Flower client for LLM fine-tuning."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: 
DictConfig, + trainset, + tokenizer, + formatting_prompts_func, + data_collator, + save_path, + ): # pylint: disable=too-many-arguments + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.train_cfg = train_cfg + self.training_argumnets = TrainingArguments(**train_cfg.training_arguments) + self.tokenizer = tokenizer + self.formatting_prompts_func = formatting_prompts_func + self.data_collator = data_collator + self.save_path = save_path + + # instantiate model + self.model = get_model(model_cfg) + + self.trainset = trainset + + def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + """Return the parameters of the current net.""" + + state_dict = get_peft_model_state_dict(self.model) + return [val.cpu().numpy() for _, val in state_dict.items()] + + def fit( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[NDArrays, int, Dict]: + """Implement distributed fit function for a given client.""" + set_parameters(self.model, parameters) + + new_lr = cosine_annealing( + int(config["current_round"]), + self.train_cfg.num_rounds, + self.train_cfg.learning_rate_max, + self.train_cfg.learning_rate_min, + ) + + self.training_argumnets.learning_rate = new_lr + self.training_argumnets.output_dir = self.save_path + + # Construct trainer + trainer = SFTTrainer( + model=self.model, + tokenizer=self.tokenizer, + args=self.training_argumnets, + max_seq_length=self.train_cfg.seq_length, + train_dataset=self.trainset, + formatting_func=self.formatting_prompts_func, + data_collator=self.data_collator, + ) + + # Do local training + results = trainer.train() + + return ( + self.get_parameters({}), + len(self.trainset), + {"train_loss": results.training_loss}, + ) + + +def set_parameters(model, parameters: NDArrays) -> None: + """Change the parameters of the model using the given ones.""" + peft_state_dict_keys = get_peft_model_state_dict(model).keys() + params_dict = zip(peft_state_dict_keys, parameters) + state_dict = 
OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + set_peft_model_state_dict(model, state_dict) + + +def gen_client_fn( + fds, + tokenizer, + formatting_prompts_func, + data_collator, + model_cfg: DictConfig, + train_cfg: DictConfig, + save_path: str, + partition_id: int = 0, + api: bool = False, +) -> Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments + """Generate the client function that creates the Flower Clients.""" + + def client_fn(cid: str) -> FlowerClient: + """Create a Flower client representing a single organization.""" + + # Let's get the partition corresponding to the i-th client + client_trainset = ( + fds.load_partition(partition_id, "train") + if api + else fds.load_partition(int(cid), "train") + ) + client_trainset = client_trainset.rename_column("output", "response") + + return FlowerClient( + model_cfg, + train_cfg, + client_trainset, + tokenizer, + formatting_prompts_func, + data_collator, + save_path, + ).to_client() + + return client_fn diff --git a/examples/llm-flowertune/conf/config.yaml b/examples/llm-flowertune/conf/config.yaml new file mode 100644 index 000000000000..0b769d351479 --- /dev/null +++ b/examples/llm-flowertune/conf/config.yaml @@ -0,0 +1,45 @@ +# Federated Instruction Tuning on General Dataset +--- + +num_clients: 20 # total number of clients +num_rounds: 100 + +dataset: + name: "vicgalle/alpaca-gpt4" + +model: + name: "openlm-research/open_llama_7b_v2" + quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes + gradient_checkpointing: True + lora: + peft_lora_r: 32 + peft_lora_alpha: 64 + +train: + num_rounds: ${num_rounds} + save_every_round: 5 + learning_rate_max: 5e-5 + learning_rate_min: 1e-6 + seq_length: 512 + training_arguments: + output_dir: null # to be set by hydra + learning_rate: null # to be set by the client + per_device_train_batch_size: 16 + gradient_accumulation_steps: 1 + logging_steps: 10 + num_train_epochs: 3 + max_steps: 10 + report_to: null + save_steps: 1000 
+ save_total_limit: 10 + gradient_checkpointing: ${model.gradient_checkpointing} + lr_scheduler_type: "constant" + +strategy: + _target_: flwr.server.strategy.FedAvg + fraction_fit: 0.1 # sample 10% of clients (i.e. 2 per round) + fraction_evaluate: 0.0 # no client evaluation + +client_resources: + num_cpus: 8 + num_gpus: 1.0 diff --git a/examples/llm-flowertune/dataset.py b/examples/llm-flowertune/dataset.py new file mode 100644 index 000000000000..571be31f7fba --- /dev/null +++ b/examples/llm-flowertune/dataset.py @@ -0,0 +1,29 @@ +from transformers import AutoTokenizer +from trl import DataCollatorForCompletionOnlyLM + + +def formatting_prompts_func(example): + output_texts = [] + # Constructing a standard Alpaca (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt + mssg = "Below is an instruction that describes a task. Write a response that appropriately completes the request." + for i in range(len(example["instruction"])): + text = f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n### Response: {example['response'][i]}" + output_texts.append(text) + return output_texts + + +def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str): + # From: https://huggingface.co/docs/trl/en/sft_trainer + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=True, padding_side="right" + ) + tokenizer.pad_token = tokenizer.eos_token + response_template_with_context = "\n### Response:" # alpaca response tag + response_template_ids = tokenizer.encode( + response_template_with_context, add_special_tokens=False + )[2:] + data_collator = DataCollatorForCompletionOnlyLM( + response_template_ids, tokenizer=tokenizer + ) + + return tokenizer, data_collator, formatting_prompts_func diff --git a/examples/llm-flowertune/main.py b/examples/llm-flowertune/main.py new file mode 100644 index 000000000000..2d03e9cbcae5 --- /dev/null +++ b/examples/llm-flowertune/main.py @@ -0,0 +1,94 @@ +import warnings +import pickle + +import flwr as fl 
+from flwr_datasets import FederatedDataset + +import hydra +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from dataset import get_tokenizer_and_data_collator_and_propt_formatting +from utils import get_on_fit_config, fit_weighted_average, get_evaluate_fn +from client import gen_client_fn + + +warnings.filterwarnings("ignore", category=UserWarning) + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg: DictConfig) -> None: + """Run federated LLM fine-tuning. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # Print config structured as YAML + print(OmegaConf.to_yaml(cfg)) + + # Partition dataset and get dataloaders + fds = FederatedDataset( + dataset=cfg.dataset.name, partitioners={"train": cfg.num_clients} + ) + ( + tokenizer, + data_collator, + formatting_prompts_func, + ) = get_tokenizer_and_data_collator_and_propt_formatting( + cfg.model.name, + ) + + # Hydra automatically creates an output directory + # Let's retrieve it and save some results there + save_path = HydraConfig.get().runtime.output_dir + + # Prepare function that will be used to spawn each client + client_fn = gen_client_fn( + fds, + tokenizer, + formatting_prompts_func, + data_collator, + cfg.model, + cfg.train, + save_path, + ) + + # Instantiate strategy according to config. Here we pass other arguments + # that are only defined at run time. 
+    strategy = instantiate(
+        cfg.strategy,
+        on_fit_config_fn=get_on_fit_config(),
+        fit_metrics_aggregation_fn=fit_weighted_average,
+        evaluate_fn=get_evaluate_fn(
+            cfg.model, cfg.train.save_every_round, cfg.num_rounds, save_path
+        ),
+    )
+
+    # Start simulation
+    history = fl.simulation.start_simulation(
+        client_fn=client_fn,
+        num_clients=cfg.num_clients,
+        config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
+        client_resources={
+            "num_cpus": cfg.client_resources.num_cpus,
+            "num_gpus": cfg.client_resources.num_gpus,
+        },
+        strategy=strategy,
+    )
+
+    # Experiment completed. Now we save the results and
+    # generate plots using the `history`
+    print("................")
+    print(history)
+
+    # Save results as a Python pickle using a file path inside
+    # the directory created by Hydra for each run
+    with open(f"{save_path}/results.pkl", "wb") as f:
+        pickle.dump(history, f)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/llm-flowertune/models.py b/examples/llm-flowertune/models.py
new file mode 100644
index 000000000000..78eef75d10d2
--- /dev/null
+++ b/examples/llm-flowertune/models.py
@@ -0,0 +1,56 @@
+import torch
+from omegaconf import DictConfig
+from transformers import AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from peft import get_peft_model, LoraConfig
+from peft.utils import prepare_model_for_kbit_training
+
+import math
+
+
+def cosine_annealing(
+    current_round: int,
+    total_round: int,
+    lrate_max: float = 0.001,
+    lrate_min: float = 0.0,
+) -> float:
+    """Implement cosine annealing learning rate schedule."""
+
+    cos_inner = math.pi * current_round / total_round
+    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
+
+
+def get_model(model_cfg: DictConfig):
+    """Load model with appropriate quantization config and other optimizations.
+ + Please refer to this example for `peft + BitsAndBytes`: + https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py + """ + + if model_cfg.quantization == 4: + quantization_config = BitsAndBytesConfig(load_in_4bit=True) + elif model_cfg.quantization == 8: + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + else: + raise ValueError( + f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/" + ) + + model = AutoModelForCausalLM.from_pretrained( + model_cfg.name, + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=model_cfg.gradient_checkpointing + ) + + peft_config = LoraConfig( + r=model_cfg.lora.peft_lora_r, + lora_alpha=model_cfg.lora.peft_lora_alpha, + lora_dropout=0.075, + task_type="CAUSAL_LM", + ) + + return get_peft_model(model, peft_config) diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt new file mode 100644 index 000000000000..c7ff57b403f7 --- /dev/null +++ b/examples/llm-flowertune/requirements.txt @@ -0,0 +1,8 @@ +flwr-nightly[rest,simulation] +flwr_datasets==0.0.2 +hydra-core==1.3.2 +trl==0.7.2 +bitsandbytes==0.41.3 +scipy==1.11.2 +peft==0.4.0 +fschat[model_worker,webui]==0.2.35 diff --git a/examples/llm-flowertune/test.py b/examples/llm-flowertune/test.py new file mode 100644 index 000000000000..652bb9aafcf5 --- /dev/null +++ b/examples/llm-flowertune/test.py @@ -0,0 +1,73 @@ +# This python file is adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py + +import argparse +import torch +from peft import AutoPeftModelForCausalLM +from transformers import AutoTokenizer +from fastchat.conversation import get_conv_template + + +parser = argparse.ArgumentParser() +parser.add_argument("--peft-path", type=str, default=None) +parser.add_argument("--question", type=str, default="How are you") 
+parser.add_argument("--template", type=str, default="vicuna_v1.1") +args = parser.parse_args() + +# Load model and tokenizer +model = AutoPeftModelForCausalLM.from_pretrained( + args.peft_path, torch_dtype=torch.float16 +).to("cuda") +base_model = model.peft_config["default"].base_model_name_or_path +tokenizer = AutoTokenizer.from_pretrained(base_model) + +# Generate answers +temperature = 0.7 +choices = [] +conv = get_conv_template(args.template) + +conv.append_message(conv.roles[0], args.question) +conv.append_message(conv.roles[1], None) +prompt = conv.get_prompt() +input_ids = tokenizer([prompt]).input_ids + +output_ids = model.generate( + input_ids=torch.as_tensor(input_ids).cuda(), + do_sample=True, + temperature=temperature, + max_new_tokens=1024, +) + +output_ids = ( + output_ids[0] + if model.config.is_encoder_decoder + else output_ids[0][len(input_ids[0]) :] +) + +# Be consistent with the template's stop_token_ids +if conv.stop_token_ids: + stop_token_ids_index = [ + i for i, id in enumerate(output_ids) if id in conv.stop_token_ids + ] + if len(stop_token_ids_index) > 0: + output_ids = output_ids[: stop_token_ids_index[0]] + +output = tokenizer.decode( + output_ids, + spaces_between_special_tokens=False, +) + +if conv.stop_str and output.find(conv.stop_str) > 0: + output = output[: output.find(conv.stop_str)] + +for special_token in tokenizer.special_tokens_map.values(): + if isinstance(special_token, list): + for special_tok in special_token: + output = output.replace(special_tok, "") + else: + output = output.replace(special_token, "") +output = output.strip() + +conv.update_last_message(output) + +print(f">> prompt: {prompt}") +print(f">> Generated: {output}") diff --git a/examples/llm-flowertune/utils.py b/examples/llm-flowertune/utils.py new file mode 100644 index 000000000000..bbb607810537 --- /dev/null +++ b/examples/llm-flowertune/utils.py @@ -0,0 +1,43 @@ +from client import set_parameters +from models import get_model + + +# Get function that 
will be executed by the strategy's evaluate() method
+# Here we use it to save global model checkpoints
+def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
+    """Return an evaluation function for saving global model."""
+
+    def evaluate(server_round: int, parameters, config):
+        # Save model
+        if server_round != 0 and (
+            server_round == total_round or server_round % save_every_round == 0
+        ):
+            # Init model
+            model = get_model(model_cfg)
+            set_parameters(model, parameters)
+
+            model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+        return 0.0, {}
+
+    return evaluate
+
+
+# Get a function that will be used to construct the config that the client's
+# fit() method will receive
+def get_on_fit_config():
+    def fit_config_fn(server_round: int):
+        fit_config = {"current_round": server_round}
+        return fit_config
+
+    return fit_config_fn
+
+
+def fit_weighted_average(metrics):
+    """Aggregation function for (federated) training metrics."""
+    # Multiply train loss of each client by number of examples used
+    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"train_loss": sum(losses) / sum(examples)}
diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md
deleted file mode 100644
index 65ef000c26f2..000000000000
--- a/examples/mt-pytorch-callable/README.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Deploy 🧪
-
-🧪 = this page covers experimental features that might change in future versions of Flower
-
-This how-to guide describes the deployment of a long-running Flower server.
-
-## Preconditions
-
-Let's assume the following project structure:
-
-```bash
-$ tree .
-.
-└── client.py -├── driver.py -├── requirements.txt -``` - -## Install dependencies - -```bash -pip install -r requirements.txt -``` - -## Start the long-running Flower server - -```bash -flower-server --insecure -``` - -## Start the long-running Flower client - -In a new terminal window, start the first long-running Flower client: - -```bash -flower-client --callable client:flower -``` - -In yet another new terminal window, start the second long-running Flower client: - -```bash -flower-client --callable client:flower -``` - -## Start the Driver script - -```bash -python driver.py -``` diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py deleted file mode 100644 index 6f9747784ae0..000000000000 --- a/examples/mt-pytorch-callable/client.py +++ /dev/null @@ -1,123 +0,0 @@ -import warnings -from collections import OrderedDict - -import flwr as fl -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR10 -from torchvision.transforms import Compose, Normalize, ToTensor -from tqdm import tqdm - - -# ############################################################################# -# 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader -# ############################################################################# - -warnings.filterwarnings("ignore", category=UserWarning) -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Net(nn.Module): - """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" - - def __init__(self) -> None: - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return self.fc3(x) - - -def train(net, trainloader, epochs): - """Train the model on the training set.""" - criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - for _ in range(epochs): - for images, labels in tqdm(trainloader): - optimizer.zero_grad() - criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() - optimizer.step() - - -def test(net, testloader): - """Validate the model on the test set.""" - criterion = torch.nn.CrossEntropyLoss() - correct, loss = 0, 0.0 - with torch.no_grad(): - for images, labels in tqdm(testloader): - outputs = net(images.to(DEVICE)) - labels = labels.to(DEVICE) - loss += criterion(outputs, labels).item() - correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() - accuracy = correct / len(testloader.dataset) - return loss, accuracy - - -def load_data(): - """Load CIFAR-10 (training and test set).""" - trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - trainset = CIFAR10("./data", train=True, download=True, transform=trf) - testset = CIFAR10("./data", train=False, download=True, 
transform=trf) - return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) - - -# ############################################################################# -# 2. Federation of the pipeline with Flower -# ############################################################################# - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data() - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - return [val.cpu().numpy() for _, val in net.state_dict().items()] - - def set_parameters(self, parameters): - params_dict = zip(net.state_dict().keys(), parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - net.load_state_dict(state_dict, strict=True) - - def fit(self, parameters, config): - self.set_parameters(parameters) - train(net, trainloader, epochs=1) - return self.get_parameters(config={}), len(trainloader.dataset), {} - - def evaluate(self, parameters, config): - self.set_parameters(parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} - - -def client_fn(cid: str): - """.""" - return FlowerClient().to_client() - - -# To run this: `flower-client --callable client:flower` -flower = fl.flower.Flower( - client_fn=client_fn, -) - - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient().to_client(), - transport="grpc-rere", - ) diff --git a/examples/mt-pytorch-callable/driver.py b/examples/mt-pytorch-callable/driver.py deleted file mode 100644 index 1248672b6813..000000000000 --- a/examples/mt-pytorch-callable/driver.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, Tuple - -import flwr as fl -from flwr.common import Metrics - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - # Multiply accuracy of each 
client by number of examples used - accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] - examples = [num_examples for num_examples, _ in metrics] - - # Aggregate and return custom metric (weighted average) - return {"accuracy": sum(accuracies) / sum(examples)} - - -# Define strategy -strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) - -# Start Flower driver -fl.driver.start_driver( - server_address="0.0.0.0:9091", - config=fl.server.ServerConfig(num_rounds=3), - strategy=strategy, -) diff --git a/examples/mt-pytorch-callable/requirements.txt b/examples/mt-pytorch-callable/requirements.txt deleted file mode 100644 index 797ca6db6244..000000000000 --- a/examples/mt-pytorch-callable/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr>=1.0, <2.0 -torch==1.13.1 -torchvision==0.14.1 -tqdm==4.65.0 diff --git a/examples/mt-pytorch/README.md b/examples/mt-pytorch/README.md deleted file mode 100644 index ef9516314e26..000000000000 --- a/examples/mt-pytorch/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Multi-Tenant Federated Learning with Flower and PyTorch - -This example contains experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. 
- -## Setup - -```bash -./dev/venv-reset.sh -``` - -## Run with Driver API - -Terminal 1: start Flower server - -```bash -flower-server -``` - -Terminal 2+3: start two Flower client nodes - -```bash -python client.py -``` - -Terminal 4: start Driver script - -Using: - -```bash -python start_driver.py -``` - -Or, alternatively: - -```bash -python driver.py -``` - -## Run in legacy mode - -Terminal 1: start Flower server - -```bash -python server.py -``` - -Terminal 2+3: start two clients - -```bash -python client.py -``` - -## Run with Driver API (REST transport layer) - -Terminal 1: start Flower server and enable REST transport layer - -```bash -flower-server --rest -``` - -Terminal 2: start Driver script - -```bash -python driver.py -``` - -Open file `client.py` adjust `server_address` and `transport`. - -Terminal 3+4: start two Flower client nodes - -```bash -python client.py -``` diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py deleted file mode 100644 index 23cc736fd62b..000000000000 --- a/examples/mt-pytorch/client.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Dict -import flwr as fl -from flwr.common import NDArrays, Scalar - -from task import ( - Net, - DEVICE, - load_data, - get_parameters, - set_parameters, - train, - test, -) - - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data() - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - return get_parameters(net) - - def fit(self, parameters, config): - set_parameters(net, parameters) - results = train(net, trainloader, testloader, epochs=1, device=DEVICE) - return get_parameters(net), len(trainloader.dataset), results - - def evaluate(self, parameters, config): - set_parameters(net, parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} - - -# Start Flower client 
-fl.client.start_numpy_client( - server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST - client=FlowerClient(), - transport="grpc-rere", # "rest" for REST -) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py deleted file mode 100644 index fed760f021af..000000000000 --- a/examples/mt-pytorch/driver.py +++ /dev/null @@ -1,226 +0,0 @@ -from typing import List, Tuple -import random -import time - -from flwr.driver import GrpcDriver -from flwr.common import ( - ServerMessage, - FitIns, - ndarrays_to_parameters, - serde, - parameters_to_ndarrays, - ClientMessage, - NDArrays, - Code, -) -from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 -from flwr.server.strategy.aggregate import aggregate -from flwr.common import Metrics -from flwr.server import History -from flwr.common import serde -from task import Net, get_parameters, set_parameters - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) -# 
-------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 2 -sleep_time = 1 -num_rounds = 3 -parameters = ndarrays_to_parameters(get_parameters(net=Net())) - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_nodes: List[node_pb2.Node] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. - # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids: List[int] = [0] * num_client_nodes_per_round - sampled_nodes = [ - node_pb2.Node(node_id=node_id, anonymous=False) - for node_id in sampled_node_ids - ] - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. 
- while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_nodes: List[node_pb2.Node] = get_nodes_res.nodes - print(f"Got {len(all_nodes)} client nodes") - - if len(all_nodes) >= num_client_nodes_per_round: - # Sample client nodes - sampled_nodes = random.sample(all_nodes, num_client_nodes_per_round) - break - - time.sleep(3) - - # Log sampled node IDs - print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") - time.sleep(sleep_time) - - # Schedule a task for all sampled nodes - fit_ins: FitIns = FitIns(parameters=parameters, config={}) - server_message_proto: transport_pb2.ServerMessage = serde.server_message_to_proto( - server_message=ServerMessage(fit_ins=fit_ins) - ) - task_ins_list: List[task_pb2.TaskIns] = [] - for sampled_node in sampled_nodes: - new_task_ins = task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - workload_id=workload_id, - task=task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=sampled_node, - legacy_server_message=server_message_proto, - ), - ) - task_ins_list.append(new_task_ins) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - print( - f"Scheduled {len(push_task_ins_res.task_ids)} tasks: {push_task_ins_res.task_ids}" - ) - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - 
task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - print(f"Got {len(task_res_list)} results") - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages: List[ClientMessage] = [] - for task_res in all_task_res: - if task_res.task.HasField("legacy_client_message"): - node_messages.append(task_res.task.legacy_client_message) - print(f"Received {len(node_messages)} results") - - weights_results: List[Tuple[NDArrays, int]] = [] - metrics_results: List = [] - for node_message in node_messages: - if not node_message.fit_res: - continue - fit_res = node_message.fit_res - # Aggregate only if the status is OK - if fit_res.status.code != Code.OK.value: - continue - weights_results.append( - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - ) - metrics_results.append( - (fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics)) - ) - - # Aggregate parameters (FedAvg) - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated - - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) - - # Slow down the start of the next round - time.sleep(sleep_time) 
- -print("app_fit: losses_distributed %s", str(history.losses_distributed)) -print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit)) -print("app_fit: metrics_distributed %s", str(history.metrics_distributed)) -print("app_fit: losses_centralized %s", str(history.losses_centralized)) -print("app_fit: metrics_centralized %s", str(history.metrics_centralized)) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/mxnet-from-centralized-to-federated/README.md b/examples/mxnet-from-centralized-to-federated/README.md index 839d3b16a1cf..2c3f240d8978 100644 --- a/examples/mxnet-from-centralized-to-federated/README.md +++ b/examples/mxnet-from-centralized-to-federated/README.md @@ -1,5 +1,7 @@ # MXNet: From Centralized To Federated +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommnended. + This example demonstrates how an already existing centralized MXNet-based machine learning project can be federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet project. 
diff --git a/examples/mxnet-from-centralized-to-federated/client.py b/examples/mxnet-from-centralized-to-federated/client.py index 3a3a9de146ef..bb666a26508e 100644 --- a/examples/mxnet-from-centralized-to-federated/client.py +++ b/examples/mxnet-from-centralized-to-federated/client.py @@ -1,6 +1,5 @@ """Flower client example using MXNet for MNIST classification.""" - from typing import Dict, List, Tuple import flwr as fl diff --git a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py index a53ed018eb48..5cf39da7c9ca 100644 --- a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py +++ b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - from typing import List, Tuple import mxnet as mx from mxnet import gluon diff --git a/examples/mxnet-from-centralized-to-federated/pyproject.toml b/examples/mxnet-from-centralized-to-federated/pyproject.toml index a0d31f76ebdd..b00b3ddfe412 100644 --- a/examples/mxnet-from-centralized-to-federated/pyproject.toml +++ b/examples/mxnet-from-centralized-to-federated/pyproject.toml @@ -6,11 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/mxnet-from-centralized-to-federated/requirements.txt b/examples/mxnet-from-centralized-to-federated/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/mxnet-from-centralized-to-federated/requirements.txt +++ b/examples/mxnet-from-centralized-to-federated/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 
+flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/mxnet-from-centralized-to-federated/server.py b/examples/mxnet-from-centralized-to-federated/server.py index 29cbce1884d1..871aa4e8ec99 100644 --- a/examples/mxnet-from-centralized-to-federated/server.py +++ b/examples/mxnet-from-centralized-to-federated/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - import flwr as fl if __name__ == "__main__": diff --git a/examples/opacus/dp_cifar_client.py b/examples/opacus/dp_cifar_client.py index bab1451ba707..cc30e7728222 100644 --- a/examples/opacus/dp_cifar_client.py +++ b/examples/opacus/dp_cifar_client.py @@ -28,7 +28,7 @@ def load_data(): model = Net() trainloader, testloader, sample_rate = load_data() -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=DPCifarClient(model, trainloader, testloader), + client=DPCifarClient(model, trainloader, testloader).to_client(), ) diff --git a/examples/opacus/dp_cifar_simulation.py b/examples/opacus/dp_cifar_simulation.py index 14a9d037685b..d957caf8785c 100644 --- a/examples/opacus/dp_cifar_simulation.py +++ b/examples/opacus/dp_cifar_simulation.py @@ -1,14 +1,14 @@ import math from collections import OrderedDict -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import flwr as fl import numpy as np import torch import torchvision.transforms as transforms -from opacus.dp_model_inspector import DPModelInspector from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 +from flwr.common.typing import Scalar from dp_cifar_main import DEVICE, PARAMS, DPCifarClient, Net, test @@ -23,8 +23,6 @@ def client_fn(cid: str) -> fl.client.Client: # Load model. model = Net() # Check model is compatible with Opacus. - # inspector = DPModelInspector() - # print(f"Is the model valid? {inspector.validate(model)}") # Load data partition (divide CIFAR10 into NUM_CLIENTS distinct partitions, using 30% for validation). 
transform = transforms.Compose( @@ -45,12 +43,14 @@ def client_fn(cid: str) -> fl.client.Client: client_trainloader = DataLoader(client_trainset, PARAMS["batch_size"]) client_testloader = DataLoader(client_testset, PARAMS["batch_size"]) - return DPCifarClient(model, client_trainloader, client_testloader) + return DPCifarClient(model, client_trainloader, client_testloader).to_client() # Define an evaluation function for centralized evaluation (using whole CIFAR10 testset). def get_evaluate_fn() -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: - def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: + def evaluate( + server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] + ): transform = transforms.Compose( [ transforms.ToTensor(), @@ -63,7 +63,7 @@ def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: state_dict = OrderedDict( { k: torch.tensor(np.atleast_1d(v)) - for k, v in zip(model.state_dict().keys(), weights) + for k, v in zip(model.state_dict().keys(), parameters) } ) model.load_state_dict(state_dict, strict=True) @@ -82,7 +82,7 @@ def main() -> None: client_fn=client_fn, num_clients=NUM_CLIENTS, client_resources={"num_cpus": 1}, - num_rounds=3, + config=fl.server.ServerConfig(num_rounds=3), strategy=fl.server.strategy.FedAvg( fraction_fit=0.1, fraction_evaluate=0.1, evaluate_fn=get_evaluate_fn() ), diff --git a/examples/opacus/pyproject.toml b/examples/opacus/pyproject.toml index af0eaf596fbf..26914fa27aa4 100644 --- a/examples/opacus/pyproject.toml +++ b/examples/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_opacus" version = "0.1.0" description = "Differentially Private Federated Learning with Opacus and Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-federated-variational-autoencoder/client.py 
b/examples/pytorch-federated-variational-autoencoder/client.py index ceb55c79f564..fc71f7e70c0b 100644 --- a/examples/pytorch-federated-variational-autoencoder/client.py +++ b/examples/pytorch-federated-variational-autoencoder/client.py @@ -93,7 +93,9 @@ def evaluate(self, parameters, config): loss = test(net, testloader) return float(loss), len(testloader), {} - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() + ) if __name__ == "__main__": diff --git a/examples/pytorch-federated-variational-autoencoder/pyproject.toml b/examples/pytorch-federated-variational-autoencoder/pyproject.toml index 116140306f62..bc1f85803682 100644 --- a/examples/pytorch-federated-variational-autoencoder/pyproject.toml +++ b/examples/pytorch-federated-variational-autoencoder/pyproject.toml @@ -2,7 +2,7 @@ name = "pytorch_federated_variational_autoencoder" version = "0.1.0" description = "Federated Variational Autoencoder Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index 40f7f40e5adc..06ee89dddcac 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -2,7 +2,7 @@ This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. -This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. +This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. 
The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup diff --git a/examples/pytorch-from-centralized-to-federated/__init__.py b/examples/pytorch-from-centralized-to-federated/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index 3c1d67d2f445..277a21da2e70 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -6,22 +6,20 @@ https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html """ - # mypy: ignore-errors # pylint: disable=W0223 -from typing import Tuple, Dict +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms from torch import Tensor -from torchvision.datasets import CIFAR10 +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, ToTensor, Normalize -DATA_ROOT = "./dataset" +from flwr_datasets import FederatedDataset # pylint: disable=unsubscriptable-object @@ -53,19 +51,25 @@ def forward(self, x: Tensor) -> Tensor: return x -def load_data() -> ( - Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict] -): - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] +def load_data(partition_id: int): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(partition_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = 
partition.train_test_split(test_size=0.2) + pytorch_transforms = Compose( + [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) - trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform) - trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True) - testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform) - testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False) - num_examples = {"trainset": len(trainset), "testset": len(testset)} - return trainloader, testloader, num_examples + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch + + partition_train_test = partition_train_test.with_transform(apply_transforms) + trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) + testloader = DataLoader(partition_train_test["test"], batch_size=32) + return trainloader, testloader def train( @@ -87,7 +91,7 @@ def train( for epoch in range(epochs): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["img"].to(device), data["label"].to(device) # zero the parameter gradients optimizer.zero_grad() @@ -120,7 +124,7 @@ def test( net.eval() with torch.no_grad(): for data in testloader: - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["img"].to(device), data["label"].to(device) outputs = net(images) loss += criterion(outputs, labels).item() _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member @@ -133,7 +137,7 @@ def main(): DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("Centralized PyTorch training") print("Load data") - trainloader, testloader, _ = load_data() + trainloader, testloader = load_data(0) net = Net().to(DEVICE) 
net.eval() print("Start training") diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index 88678e0569b7..9df4739e0aab 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -1,24 +1,23 @@ """Flower client example using PyTorch for CIFAR-10 image classification.""" - -import os -import sys -import timeit +import argparse from collections import OrderedDict from typing import Dict, List, Tuple -import flwr as fl import numpy as np import torch -import torchvision +from datasets.utils.logging import disable_progress_bar +from torch.utils.data import DataLoader import cifar +import flwr as fl + +disable_progress_bar() + USE_FEDBN: bool = True -# pylint: disable=no-member -DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# pylint: enable=no-member +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Flower Client @@ -28,19 +27,18 @@ class CifarClient(fl.client.NumPyClient): def __init__( self, model: cifar.Net, - trainloader: torch.utils.data.DataLoader, - testloader: torch.utils.data.DataLoader, - num_examples: Dict, + trainloader: DataLoader, + testloader: DataLoader, ) -> None: self.model = model self.trainloader = trainloader self.testloader = testloader - self.num_examples = num_examples def get_parameters(self, config: Dict[str, str]) -> List[np.ndarray]: self.model.train() if USE_FEDBN: - # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN + # Return model parameters as a list of NumPy ndarrays, excluding + # parameters of BN layers when using FedBN return [ val.cpu().numpy() for name, val in self.model.state_dict().items() @@ -69,7 +67,7 @@ def fit( # Set model parameters, train model, return updated model parameters self.set_parameters(parameters) cifar.train(self.model, self.trainloader, epochs=1, 
device=DEVICE) - return self.get_parameters(config={}), self.num_examples["trainset"], {} + return self.get_parameters(config={}), len(self.trainloader.dataset), {} def evaluate( self, parameters: List[np.ndarray], config: Dict[str, str] @@ -77,24 +75,27 @@ def evaluate( # Set model parameters, evaluate model on local test dataset, return result self.set_parameters(parameters) loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE) - return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)} + return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)} def main() -> None: """Load data, start CifarClient.""" + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument("--partition-id", type=int, required=True, choices=range(0, 10)) + args = parser.parse_args() # Load data - trainloader, testloader, num_examples = cifar.load_data() + trainloader, testloader = cifar.load_data(args.partition_id) # Load model model = cifar.Net().to(DEVICE).train() # Perform a single forward pass to properly initialize BatchNorm - _ = model(next(iter(trainloader))[0].to(DEVICE)) + _ = model(next(iter(trainloader))["img"].to(DEVICE)) # Start client - client = CifarClient(model, trainloader, testloader, num_examples) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(model, trainloader, testloader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/pytorch-from-centralized-to-federated/pyproject.toml b/examples/pytorch-from-centralized-to-federated/pyproject.toml index 73999a9e6cd4..3d1559e3a515 100644 --- a/examples/pytorch-from-centralized-to-federated/pyproject.toml +++ b/examples/pytorch-from-centralized-to-federated/pyproject.toml @@ -6,10 +6,11 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch: From Centralized To 
Federated with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" torchvision = "0.14.1" diff --git a/examples/pytorch-from-centralized-to-federated/requirements.txt b/examples/pytorch-from-centralized-to-federated/requirements.txt index f3caddbc875e..ba4afad9c288 100644 --- a/examples/pytorch-from-centralized-to-federated/requirements.txt +++ b/examples/pytorch-from-centralized-to-federated/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 diff --git a/examples/pytorch-from-centralized-to-federated/run.sh b/examples/pytorch-from-centralized-to-federated/run.sh index c64f362086aa..6ddf6ad476b4 100755 --- a/examples/pytorch-from-centralized-to-federated/run.sh +++ b/examples/pytorch-from-centralized-to-federated/run.sh @@ -4,9 +4,9 @@ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/pytorch-from-centralized-to-federated/server.py b/examples/pytorch-from-centralized-to-federated/server.py index 29cbce1884d1..5190d690dc20 100644 --- a/examples/pytorch-from-centralized-to-federated/server.py +++ b/examples/pytorch-from-centralized-to-federated/server.py @@ -1,10 +1,27 @@ """Flower server example.""" +from typing import List, Tuple import flwr as fl +from flwr.common import Metrics -if __name__ == "__main__": - fl.server.start_server( - server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=3), - ) + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply 
accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=10), + strategy=strategy, +) diff --git a/examples/quickstart-cpp/driver.py b/examples/quickstart-cpp/driver.py index 037623ee77cf..f19cf0e9bd98 100644 --- a/examples/quickstart-cpp/driver.py +++ b/examples/quickstart-cpp/driver.py @@ -3,7 +3,7 @@ # Start Flower server for three rounds of federated learning if __name__ == "__main__": - fl.driver.start_driver( + fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=FedAvgCpp(), diff --git a/examples/quickstart-fastai/client.py b/examples/quickstart-fastai/client.py index a88abbe525dc..6bb2a751d544 100644 --- a/examples/quickstart-fastai/client.py +++ b/examples/quickstart-fastai/client.py @@ -43,7 +43,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-fastai/pyproject.toml b/examples/quickstart-fastai/pyproject.toml index ffaa97267493..19a25291a6af 100644 --- a/examples/quickstart-fastai/pyproject.toml +++ b/examples/quickstart-fastai/pyproject.toml @@ -6,9 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-fastai" version = "0.1.0" description = "Fastai Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.10" 
-flwr = "^1.0.0" -fastai = "^2.7.10" +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +fastai = "2.7.14" +torch = "2.2.0" +torchvision = "0.17.0" diff --git a/examples/quickstart-fastai/requirements.txt b/examples/quickstart-fastai/requirements.txt index 0a7f315018a2..9c6e8d77293a 100644 --- a/examples/quickstart-fastai/requirements.txt +++ b/examples/quickstart-fastai/requirements.txt @@ -1,3 +1,4 @@ -fastai~=2.7.12 -flwr~=1.4.0 -torch~=2.0.1 +flwr>=1.0, <2.0 +fastai==2.7.14 +torch==2.2.0 +torchvision==0.17.0 diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index c1e3cc4edc06..ce7790cd4af5 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,6 +1,6 @@ # Federated HuggingFace Transformers using Flower and PyTorch -This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for detailed explaination for the transformer pipeline. +This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.ai/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy. 
@@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 8717d710ad9c..9be08d0cbcf4 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -1,58 +1,48 @@ -from collections import OrderedDict +import argparse import warnings +from collections import OrderedDict import flwr as fl import torch -import numpy as np - -import random -from torch.utils.data import DataLoader - -from datasets import load_dataset from evaluate import load as load_metric - -from transformers import AutoTokenizer, DataCollatorWithPadding +from torch.optim import AdamW +from torch.utils.data import DataLoader from transformers import AutoModelForSequenceClassification -from transformers import AdamW +from transformers import AutoTokenizer, DataCollatorWithPadding + +from flwr_datasets import FederatedDataset warnings.filterwarnings("ignore", category=UserWarning) DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(): +def load_data(partition_id): """Load IMDB data (training and eval)""" - raw_datasets = load_dataset("imdb") - raw_datasets = raw_datasets.shuffle(seed=42) - - # remove unnecessary data split - del raw_datasets["unsupervised"] + fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) + partition = fds.load_partition(partition_id) + # Divide data: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) def tokenize_function(examples): return 
tokenizer(examples["text"], truncation=True) - # random 100 samples - population = random.sample(range(len(raw_datasets["train"])), 100) - - tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) - tokenized_datasets["train"] = tokenized_datasets["train"].select(population) - tokenized_datasets["test"] = tokenized_datasets["test"].select(population) - - tokenized_datasets = tokenized_datasets.remove_columns("text") - tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + partition_train_test = partition_train_test.map(tokenize_function, batched=True) + partition_train_test = partition_train_test.remove_columns("text") + partition_train_test = partition_train_test.rename_column("label", "labels") data_collator = DataCollatorWithPadding(tokenizer=tokenizer) trainloader = DataLoader( - tokenized_datasets["train"], + partition_train_test["train"], shuffle=True, batch_size=32, collate_fn=data_collator, ) testloader = DataLoader( - tokenized_datasets["test"], batch_size=32, collate_fn=data_collator + partition_train_test["test"], batch_size=32, collate_fn=data_collator ) return trainloader, testloader @@ -88,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(): +def main(partition_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data() + trainloader, testloader = load_data(partition_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -118,8 +108,20 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} # Start client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=IMDBClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=IMDBClient().to_client() + ) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + choices=list(range(1_000)), + 
required=True, + type=int, + help="Partition of the dataset divided into 1,000 iid partitions created " + "artificially.", + ) + partition_id = parser.parse_args().partition_id + main(partition_id) diff --git a/examples/quickstart-huggingface/pyproject.toml b/examples/quickstart-huggingface/pyproject.toml index eb9687c5152c..2b46804d7b45 100644 --- a/examples/quickstart-huggingface/pyproject.toml +++ b/examples/quickstart-huggingface/pyproject.toml @@ -7,13 +7,14 @@ name = "quickstart-huggingface" version = "0.1.0" description = "Hugging Face Transformers Federated Learning Quickstart with Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = ">=0.0.2,<1.0.0" torch = ">=1.13.1,<2.0" transformers = ">=4.30.0,<5.0" evaluate = ">=0.4.0,<1.0" diff --git a/examples/quickstart-huggingface/requirements.txt b/examples/quickstart-huggingface/requirements.txt index aeb2d13fc4a4..3cd5735625ba 100644 --- a/examples/quickstart-huggingface/requirements.txt +++ b/examples/quickstart-huggingface/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets>=0.0.2, <1.0.0 torch>=1.13.1, <2.0 transformers>=4.30.0, <5.0 evaluate>=0.4.0, <1.0 diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index c64f362086aa..fa989eab1471 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --partition-id ${i}& done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-jax/README.md b/examples/quickstart-jax/README.md index 9208337b86de..836adf558d88 100644 --- a/examples/quickstart-jax/README.md +++ b/examples/quickstart-jax/README.md @@ -52,7 +52,7 @@ 
pip install -r requirements.txt ## Run JAX Federated -This JAX example is based on the [Linear Regression with JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression pronlem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings being done while calling the dataset function. Please checkout out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set. +This JAX example is based on the [Linear Regression with JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression problem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings being done while calling the dataset function. Please checkout out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set. The only things we need are a simple Flower server (in `server.py`) and a Flower client (in `client.py`). 
The Flower client basically takes model and training code tells Flower how to call it. diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py index f9b056276deb..2257a3d6daa3 100644 --- a/examples/quickstart-jax/client.py +++ b/examples/quickstart-jax/client.py @@ -1,6 +1,5 @@ """Flower client example using JAX for linear regression.""" - from typing import Dict, List, Tuple, Callable import flwr as fl @@ -52,4 +51,6 @@ def evaluate( # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/quickstart-jax/jax_training.py b/examples/quickstart-jax/jax_training.py index 2b523a08516e..a2e23a0927bc 100644 --- a/examples/quickstart-jax/jax_training.py +++ b/examples/quickstart-jax/jax_training.py @@ -7,7 +7,6 @@ please read the JAX documentation or the mentioned tutorial. """ - from typing import Dict, List, Tuple, Callable import jax import jax.numpy as jnp diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index 41b4462d0a14..c956191369b5 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -2,7 +2,7 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-mlcube/client.py b/examples/quickstart-mlcube/client.py index 0a1962d8da8a..46ddd45f52ce 100644 --- a/examples/quickstart-mlcube/client.py +++ b/examples/quickstart-mlcube/client.py @@ -43,8 +43,9 @@ def main(): os.path.dirname(os.path.abspath(__file__)), "workspaces", workspace_name ) - fl.client.start_numpy_client( - server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace) + fl.client.start_client( 
+ server_address="0.0.0.0:8080", + client=MLCubeClient(workspace=workspace).to_client(), ) diff --git a/examples/quickstart-mlcube/pyproject.toml b/examples/quickstart-mlcube/pyproject.toml index 0d42fc3b2898..a2862bd5ebb7 100644 --- a/examples/quickstart-mlcube/pyproject.toml +++ b/examples/quickstart-mlcube/pyproject.toml @@ -6,13 +6,13 @@ build-backend = "poetry.masonry.api" name = "quickstart-ml-cube" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" # For development: { path = "../../", develop = true } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } mlcube = "0.0.9" mlcube-docker = "0.0.9" tensorflow-estimator = ">=2.9.1,<2.11.1 || >2.11.1" diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index d94a87a014f7..cca55bcb946a 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -66,19 +66,19 @@ following commands. 
Start a first client in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` And another one in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` If you want to utilize your GPU, you can use the `--gpu` argument: ```shell -python3 client.py --gpu --node-id 2 +python3 client.py --gpu --partition-id 2 ``` Note that you can start many more clients if you want, but each will have to be in its own terminal. @@ -96,7 +96,7 @@ We will use `flwr_datasets` to easily download and partition the `MNIST` dataset ```python fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) -partition = fds.load_partition(node_id = args.node_id) +partition = fds.load_partition(partition_id = args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits['train'].set_format("numpy") diff --git a/examples/quickstart-mlx/client.py b/examples/quickstart-mlx/client.py index 3b506399a5f1..faba2b94d6bd 100644 --- a/examples/quickstart-mlx/client.py +++ b/examples/quickstart-mlx/client.py @@ -89,7 +89,7 @@ def evaluate(self, parameters, config): parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", @@ -106,7 +106,7 @@ def evaluate(self, parameters, config): learning_rate = 1e-1 fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) - partition = fds.load_partition(node_id=args.node_id) + partition = fds.load_partition(partition_id=args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits["train"].set_format("numpy") diff --git a/examples/quickstart-mlx/pyproject.toml b/examples/quickstart-mlx/pyproject.toml index deb541c5ba9c..752040b6aaa9 
100644 --- a/examples/quickstart-mlx/pyproject.toml +++ b/examples/quickstart-mlx/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-mlx" version = "0.1.0" description = "MLX Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" mlx = "==0.0.3" numpy = "==1.24.4" -flwr-datasets = {extras = ["vision"], version = "^0.0.2"} +flwr-datasets = { extras = ["vision"], version = "^0.0.2" } diff --git a/examples/quickstart-mlx/run.sh b/examples/quickstart-mlx/run.sh index 70281049517d..40d211848c07 100755 --- a/examples/quickstart-mlx/run.sh +++ b/examples/quickstart-mlx/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-monai/.gitignore b/examples/quickstart-monai/.gitignore new file mode 100644 index 000000000000..a218cab9669e --- /dev/null +++ b/examples/quickstart-monai/.gitignore @@ -0,0 +1 @@ +MedNIST* diff --git a/examples/quickstart-monai/README.md b/examples/quickstart-monai/README.md new file mode 100644 index 000000000000..4a9afef4f86a --- /dev/null +++ b/examples/quickstart-monai/README.md @@ -0,0 +1,85 @@ +# Flower Example using MONAI + +This introductory example to Flower uses MONAI, but deep knowledge of MONAI is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. +Running this example in itself is quite easy. + +[MONAI](https://docs.monai.io/en/latest/index.html)(Medical Open Network for AI) is a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem. 
+ +Its ambitions are: + +- developing a community of academic, industrial and clinical researchers collaborating on a common foundation; + +- creating state-of-the-art, end-to-end training workflows for healthcare imaging; + +- providing researchers with an optimized and standardized way to create and evaluate deep learning models. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/quickstart-monai . && rm -rf _tmp && cd quickstart-monai +``` + +This will create a new directory called `quickstart-monai` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- data.py +-- model.py +-- server.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `monai` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. 
+ +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with MONAI and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands. Clients will train a [DenseNet121](https://docs.monai.io/en/stable/networks.html#densenet121) from MONAI. If a GPU is present in your system, clients will use it. + +Start client 1 in the first terminal: + +```shell +python3 client.py --partition-id 0 +``` + +Start client 2 in the second terminal: + +```shell +python3 client.py --partition-id 1 +``` + +You will see that the federated training is starting. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-monai) for a detailed explanation. diff --git a/examples/quickstart-monai/client.py b/examples/quickstart-monai/client.py new file mode 100644 index 000000000000..0ed943da83cc --- /dev/null +++ b/examples/quickstart-monai/client.py @@ -0,0 +1,61 @@ +import argparse +import warnings +from collections import OrderedDict + +import torch +from data import load_data +from model import test, train +from monai.networks.nets.densenet import DenseNet121 + +import flwr as fl + +warnings.filterwarnings("ignore", category=UserWarning) +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def __init__(self, net, trainloader, testloader, device): + self.net = net + self.trainloader = trainloader + self.testloader = testloader + self.device = device + + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in self.net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(self.net.state_dict().keys(), parameters) + state_dict = 
OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + self.net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + self.set_parameters(parameters) + train(self.net, self.trainloader, epoch_num=1, device=self.device) + return self.get_parameters(config={}), len(self.trainloader), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + loss, accuracy = test(self.net, self.testloader, self.device) + return loss, len(self.testloader), {"accuracy": accuracy} + + +if __name__ == "__main__": + total_partitions = 10 + parser = argparse.ArgumentParser() + parser.add_argument( + "--partition-id", type=int, choices=range(total_partitions), required=True + ) + args = parser.parse_args() + + # Load model and data (simple CNN, CIFAR-10) + trainloader, _, testloader, num_class = load_data( + total_partitions, args.partition_id + ) + net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_class).to(DEVICE) + + # Start Flower client + fl.client.start_numpy_client( + server_address="127.0.0.1:8080", + client=FlowerClient(net, trainloader, testloader, DEVICE), + ) diff --git a/examples/quickstart-monai/data.py b/examples/quickstart-monai/data.py new file mode 100644 index 000000000000..d184476522e8 --- /dev/null +++ b/examples/quickstart-monai/data.py @@ -0,0 +1,158 @@ +import os +import tarfile +from urllib import request + +import numpy as np +from monai.data import DataLoader, Dataset +from monai.transforms import ( + Compose, + EnsureChannelFirst, + LoadImage, + RandFlip, + RandRotate, + RandZoom, + ScaleIntensity, + ToTensor, +) + + +def _partition(files_list, labels_list, num_shards, index): + total_size = len(files_list) + assert total_size == len( + labels_list + ), f"List of datapoints and labels must be of the same length" + shard_size = total_size // num_shards + + # Calculate start and end indices for the shard + start_idx = index * shard_size + if index == num_shards - 1: + # Last shard takes the 
remainder + end_idx = total_size + else: + end_idx = start_idx + shard_size + + # Create a subset for the shard + files = files_list[start_idx:end_idx] + labels = labels_list[start_idx:end_idx] + return files, labels + + +def load_data(num_shards, index): + image_file_list, image_label_list, _, num_class = _download_data() + + # Get partition given index + files_list, labels_list = _partition( + image_file_list, image_label_list, num_shards, index + ) + + trainX, trainY, valX, valY, testX, testY = _split_data( + files_list, labels_list, len(files_list) + ) + train_transforms, val_transforms = _get_transforms() + + train_ds = MedNISTDataset(trainX, trainY, train_transforms) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True) + + val_ds = MedNISTDataset(valX, valY, val_transforms) + val_loader = DataLoader(val_ds, batch_size=300) + + test_ds = MedNISTDataset(testX, testY, val_transforms) + test_loader = DataLoader(test_ds, batch_size=300) + + return train_loader, val_loader, test_loader, num_class + + +class MedNISTDataset(Dataset): + def __init__(self, image_files, labels, transforms): + self.image_files = image_files + self.labels = labels + self.transforms = transforms + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, index): + return self.transforms(self.image_files[index]), self.labels[index] + + +def _download_data(): + data_dir = "./MedNIST/" + _download_and_extract( + "https://dl.dropboxusercontent.com/s/5wwskxctvcxiuea/MedNIST.tar.gz", + os.path.join(data_dir), + ) + + class_names = sorted( + [x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))] + ) + num_class = len(class_names) + image_files = [ + [ + os.path.join(data_dir, class_name, x) + for x in os.listdir(os.path.join(data_dir, class_name)) + ] + for class_name in class_names + ] + image_file_list = [] + image_label_list = [] + for i, class_name in enumerate(class_names): + image_file_list.extend(image_files[i]) + 
image_label_list.extend([i] * len(image_files[i])) + num_total = len(image_label_list) + return image_file_list, image_label_list, num_total, num_class + + +def _split_data(image_file_list, image_label_list, num_total): + valid_frac, test_frac = 0.1, 0.1 + trainX, trainY = [], [] + valX, valY = [], [] + testX, testY = [], [] + + for i in range(num_total): + rann = np.random.random() + if rann < valid_frac: + valX.append(image_file_list[i]) + valY.append(image_label_list[i]) + elif rann < test_frac + valid_frac: + testX.append(image_file_list[i]) + testY.append(image_label_list[i]) + else: + trainX.append(image_file_list[i]) + trainY.append(image_label_list[i]) + + return trainX, trainY, valX, valY, testX, testY + + +def _get_transforms(): + train_transforms = Compose( + [ + LoadImage(image_only=True), + EnsureChannelFirst(), + ScaleIntensity(), + RandRotate(range_x=15, prob=0.5, keep_size=True), + RandFlip(spatial_axis=0, prob=0.5), + RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True), + ToTensor(), + ] + ) + + val_transforms = Compose( + [LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity(), ToTensor()] + ) + + return train_transforms, val_transforms + + +def _download_and_extract(url, dest_folder): + if not os.path.isdir(dest_folder): + # Download the tar.gz file + tar_gz_filename = url.split("/")[-1] + if not os.path.isfile(tar_gz_filename): + with request.urlopen(url) as response, open( + tar_gz_filename, "wb" + ) as out_file: + out_file.write(response.read()) + + # Extract the tar.gz file + with tarfile.open(tar_gz_filename, "r:gz") as tar_ref: + tar_ref.extractall() diff --git a/examples/quickstart-monai/model.py b/examples/quickstart-monai/model.py new file mode 100644 index 000000000000..4c74d50553e4 --- /dev/null +++ b/examples/quickstart-monai/model.py @@ -0,0 +1,33 @@ +import torch + + +def train(model, train_loader, epoch_num, device): + loss_function = torch.nn.CrossEntropyLoss() + optimizer = 
torch.optim.Adam(model.parameters(), 1e-5) + for _ in range(epoch_num): + model.train() + for inputs, labels in train_loader: + optimizer.zero_grad() + loss_function(model(inputs.to(device)), labels.to(device)).backward() + optimizer.step() + + +def test(model, test_loader, device): + model.eval() + loss = 0.0 + y_true = list() + y_pred = list() + loss_function = torch.nn.CrossEntropyLoss() + with torch.no_grad(): + for test_images, test_labels in test_loader: + out = model(test_images.to(device)) + test_labels = test_labels.to(device) + loss += loss_function(out, test_labels).item() + pred = out.argmax(dim=1) + for i in range(len(pred)): + y_true.append(test_labels[i].item()) + y_pred.append(pred[i].item()) + accuracy = sum([1 if t == p else 0 for t, p in zip(y_true, y_pred)]) / len( + test_loader.dataset + ) + return loss, accuracy diff --git a/examples/mt-pytorch-callable/pyproject.toml b/examples/quickstart-monai/pyproject.toml similarity index 51% rename from examples/mt-pytorch-callable/pyproject.toml rename to examples/quickstart-monai/pyproject.toml index 0d1a91836006..66a56ee2270b 100644 --- a/examples/mt-pytorch-callable/pyproject.toml +++ b/examples/quickstart-monai/pyproject.toml @@ -3,14 +3,17 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "quickstart-pytorch" +name = "quickstart-monai" version = "0.1.0" -description = "PyTorch Federated Learning Quickstart with Flower" +description = "MONAI Federated Learning Quickstart with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } +flwr = ">=1.0,<2.0" torch = "1.13.1" -torchvision = "0.14.1" tqdm = "4.65.0" +scikit-learn = "1.3.1" +monai = { version = "1.3.0", extras=["gdown", "nibabel", "tqdm", "itk"] } +numpy = "1.24.4" +pillow = "10.2.0" diff --git a/examples/quickstart-monai/requirements.txt 
b/examples/quickstart-monai/requirements.txt new file index 000000000000..e3f1e463c629 --- /dev/null +++ b/examples/quickstart-monai/requirements.txt @@ -0,0 +1,7 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +tqdm==4.65.0 +scikit-learn==1.3.1 +monai[gdown,nibabel,tqdm,itk]==1.3.0 +numpy==1.24.4 +pillow==10.2.0 diff --git a/examples/mt-pytorch-callable/run.sh b/examples/quickstart-monai/run.sh similarity index 74% rename from examples/mt-pytorch-callable/run.sh rename to examples/quickstart-monai/run.sh index d2bf34f834b1..1da60bccb86d 100755 --- a/examples/mt-pytorch-callable/run.sh +++ b/examples/quickstart-monai/run.sh @@ -2,8 +2,7 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" +python -c "from data import _download_data; _download_data()" echo "Starting server" python server.py & @@ -11,7 +10,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/mt-pytorch-callable/server.py b/examples/quickstart-monai/server.py similarity index 100% rename from examples/mt-pytorch-callable/server.py rename to examples/quickstart-monai/server.py diff --git a/examples/quickstart-mxnet/README.md b/examples/quickstart-mxnet/README.md index 930cec5acdfd..37e01ef2707c 100644 --- a/examples/quickstart-mxnet/README.md +++ b/examples/quickstart-mxnet/README.md @@ -1,5 +1,7 @@ # Flower Example using MXNet +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommended.
+ This example demonstrates how to run a MXNet machine learning project federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet projects. diff --git a/examples/quickstart-mxnet/client.py b/examples/quickstart-mxnet/client.py index b0c937f7350f..6c2b2e99775d 100644 --- a/examples/quickstart-mxnet/client.py +++ b/examples/quickstart-mxnet/client.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - import flwr as fl import numpy as np import mxnet as mx diff --git a/examples/quickstart-mxnet/pyproject.toml b/examples/quickstart-mxnet/pyproject.toml index a0d31f76ebdd..b00b3ddfe412 100644 --- a/examples/quickstart-mxnet/pyproject.toml +++ b/examples/quickstart-mxnet/pyproject.toml @@ -6,11 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/quickstart-mxnet/requirements.txt b/examples/quickstart-mxnet/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/quickstart-mxnet/requirements.txt +++ b/examples/quickstart-mxnet/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 +flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/quickstart-mxnet/server.py b/examples/quickstart-mxnet/server.py index 29cbce1884d1..871aa4e8ec99 100644 --- a/examples/quickstart-mxnet/server.py +++ b/examples/quickstart-mxnet/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - import flwr as fl if __name__ == "__main__": 
diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index 2defc468c2ef..dd69f3ead3cb 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -1,6 +1,7 @@ # Flower Example using Pandas -This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. +This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to +download, partition and preprocess the dataset. Running this example in itself is quite easy. ## Project Setup @@ -69,13 +70,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -$ python3 client.py +$ python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -$ python3 client.py +$ python3 client.py --partition-id 1 ``` -You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-pandas.html) for a detailed explanation. +You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-pandas.html) for a detailed explanation. 
diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index 3feab3f6a0f4..c52b7c65b04c 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -1,4 +1,4 @@ -import warnings +import argparse from typing import Dict, List, Tuple import numpy as np @@ -6,10 +6,10 @@ import flwr as fl +from flwr_datasets import FederatedDataset -df = pd.read_csv("./data/client.csv") -column_names = ["sepal length (cm)", "sepal width (cm)"] +column_names = ["sepal_length", "sepal_width"] def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: @@ -19,23 +19,47 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client class FlowerClient(fl.client.NumPyClient): + def __init__(self, X: pd.DataFrame): + self.X = X + def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: hist_list = [] # Execute query locally - for c in column_names: - hist = compute_hist(df, c) + for c in self.X.columns: + hist = compute_hist(self.X, c) hist_list.append(hist) return ( hist_list, - len(df), + len(self.X), {}, ) -# Start Flower client -fl.client.start_numpy_client( - server_address="127.0.0.1:8080", - client=FlowerClient(), -) +if __name__ == "__main__": + N_CLIENTS = 2 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the partition id of artificially partitioned datasets.", + ) + args = parser.parse_args() + partition_id = args.partition_id + + # Load the partition data + fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) + + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + # Use just the specified columns + X = dataset[column_names] + + # Start Flower client + fl.client.start_client( + server_address="127.0.0.1:8080", + 
client=FlowerClient(X).to_client(), + ) diff --git a/examples/quickstart-pandas/pyproject.toml b/examples/quickstart-pandas/pyproject.toml index de20eaf61d63..2e6b1424bb54 100644 --- a/examples/quickstart-pandas/pyproject.toml +++ b/examples/quickstart-pandas/pyproject.toml @@ -7,11 +7,11 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } numpy = "1.23.2" pandas = "2.0.0" -scikit-learn = "1.3.1" diff --git a/examples/quickstart-pandas/requirements.txt b/examples/quickstart-pandas/requirements.txt index 14308a55faaf..d44a3c6adab9 100644 --- a/examples/quickstart-pandas/requirements.txt +++ b/examples/quickstart-pandas/requirements.txt @@ -1,4 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 numpy==1.23.2 pandas==2.0.0 -scikit-learn==1.3.1 diff --git a/examples/quickstart-pandas/run.sh b/examples/quickstart-pandas/run.sh index 6b85ce30bf45..2ae1e582b8cf 100755 --- a/examples/quickstart-pandas/run.sh +++ b/examples/quickstart-pandas/run.sh @@ -2,13 +2,9 @@ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -# Download data -mkdir -p ./data -python -c "from sklearn.datasets import load_iris; load_iris(as_frame=True)['data'].to_csv('./data/client.csv')" - for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --partition-id ${i} & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pandas/server.py b/examples/quickstart-pandas/server.py index c82304374836..af4c2a796788 100644 --- a/examples/quickstart-pandas/server.py +++ b/examples/quickstart-pandas/server.py @@ -1,5 +1,4 @@ -import pickle -from typing import 
Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -9,9 +8,6 @@ EvaluateRes, FitIns, FitRes, - Metrics, - MetricsAggregationFn, - NDArrays, Parameters, Scalar, ndarrays_to_parameters, @@ -23,11 +19,6 @@ class FedAnalytics(Strategy): - def __init__( - self, compute_fns: List[Callable] = None, col_names: List[str] = None - ) -> None: - super().__init__() - def initialize_parameters( self, client_manager: Optional[ClientManager] = None ) -> Optional[Parameters]: diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index 360efb8f6261..fb29c7e9e9ea 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1 +1,76 @@ -# Flower Examples using PyTorch Lightning +# Flower Example using PyTorch Lightning + +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-pytorch-lightning . 
&& rm -rf flower && cd quickstart-pytorch-lightning +``` + +This will create a new directory called `quickstart-pytorch-lightning` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py # client-side code +-- server.py # server-side code (including the strategy) +-- README.md +-- run.sh # runs server, then two clients +-- mnist.py # run a centralised version of this example +``` + +### Installing Dependencies + +Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with PyTorch and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to +use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the +following commands. 
+ +Start client 1 in the first terminal: + +```shell +python client.py --partition-id 0 +``` + +Start client 2 in the second terminal: + +```shell +python client.py --partition-id 1 +``` + +You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index e810d639974d..6e21259cc492 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -1,8 +1,14 @@ -import flwr as fl -import mnist -import pytorch_lightning as pl +import argparse from collections import OrderedDict + +import pytorch_lightning as pl import torch +from datasets.utils.logging import disable_progress_bar + +import flwr as fl +import mnist + +disable_progress_bar() class FlowerClient(fl.client.NumPyClient): @@ -50,13 +56,24 @@ def _set_parameters(model, parameters): def main() -> None: + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + type=int, + choices=range(0, 10), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.partition_id + # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data() + train_loader, val_loader, test_loader = mnist.load_data(partition_id) # Flower client - client = FlowerClient(model, train_loader, val_loader, test_loader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/quickstart-pytorch-lightning/mnist.py b/examples/quickstart-pytorch-lightning/mnist.py index c8f8374ecc04..95342f4fb9b3 100644 
--- a/examples/quickstart-pytorch-lightning/mnist.py +++ b/examples/quickstart-pytorch-lightning/mnist.py @@ -3,14 +3,13 @@ Source: pytorchlightning.ai (2021/02/04) """ - +from flwr_datasets import FederatedDataset +import pytorch_lightning as pl import torch from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision import transforms -from torchvision.datasets import MNIST -import pytorch_lightning as pl class LitAutoEncoder(pl.LightningModule): @@ -60,25 +59,56 @@ def _evaluate(self, batch, stage=None): self.log(f"{stage}_loss", loss, prog_bar=True) -def load_data(): - # Training / validation set - trainset = MNIST("", train=True, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(trainset, [55000, 5000]) - train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=16) - val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=16) +def collate_fn(batch): + """Change the dictionary to tuple to keep the exact dataloader behavior.""" + images = [item["image"] for item in batch] + labels = [item["label"] for item in batch] + + images_tensor = torch.stack(images) + labels_tensor = torch.tensor(labels) + + return images_tensor, labels_tensor + + +def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["image"] = [transforms.functional.to_tensor(img) for img in batch["image"]] + return batch + - # Test set - testset = MNIST("", train=False, download=True, transform=transforms.ToTensor()) - test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=16) +def load_data(partition): + fds = FederatedDataset(dataset="mnist", partitioners={"train": 10}) + partition = fds.load_partition(partition, "train") - return train_loader, val_loader, test_loader + partition = partition.with_transform(apply_transforms) + # 20 % for on 
federated evaluation + partition_full = partition.train_test_split(test_size=0.2) + # 60 % for the federated train and 20 % for the federated validation (both in fit) + partition_train_valid = partition_full["train"].train_test_split(train_size=0.75) + trainloader = DataLoader( + partition_train_valid["train"], + batch_size=32, + shuffle=True, + collate_fn=collate_fn, + num_workers=1, + ) + valloader = DataLoader( + partition_train_valid["test"], + batch_size=32, + collate_fn=collate_fn, + num_workers=1, + ) + testloader = DataLoader( + partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1 + ) + return trainloader, valloader, testloader def main() -> None: """Centralized training.""" # Load data - train_loader, val_loader, test_loader = load_data() + train_loader, val_loader, test_loader = load_data(0) # Load model model = LitAutoEncoder() diff --git a/examples/quickstart-pytorch-lightning/pyproject.toml b/examples/quickstart-pytorch-lightning/pyproject.toml index 0a1e1376b8cb..a09aaa3d65b5 100644 --- a/examples/quickstart-pytorch-lightning/pyproject.toml +++ b/examples/quickstart-pytorch-lightning/pyproject.toml @@ -6,11 +6,12 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" description = "Federated Learning Quickstart with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" flwr = ">=1.0,<2.0" # flwr = { path = "../../", develop = true } # Development +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } pytorch-lightning = "1.6.0" torchvision = "0.14.1" diff --git a/examples/quickstart-pytorch-lightning/requirements.txt b/examples/quickstart-pytorch-lightning/requirements.txt index 1cd0b31fa0b5..6530dcc8c52c 100644 --- a/examples/quickstart-pytorch-lightning/requirements.txt +++ b/examples/quickstart-pytorch-lightning/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 
+flwr-datasets[vision]>=0.0.2, <1.0.0 pytorch_lightning>=1.4.7 torchvision==0.14.1 diff --git a/examples/quickstart-pytorch-lightning/run.sh b/examples/quickstart-pytorch-lightning/run.sh index 2b6507bc154c..62a1dac199bd 100755 --- a/examples/quickstart-pytorch-lightning/run.sh +++ b/examples/quickstart-pytorch-lightning/run.sh @@ -4,9 +4,9 @@ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch-lightning/server.py b/examples/quickstart-pytorch-lightning/server.py index 370186ae1d98..a104a1fffd26 100644 --- a/examples/quickstart-pytorch-lightning/server.py +++ b/examples/quickstart-pytorch-lightning/server.py @@ -11,7 +11,7 @@ def main() -> None: # Start Flower server for three rounds of federated learning fl.server.start_server( server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=10), + config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, ) diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 6de0dcf7ab32..02c9b4b38498 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Example using PyTorch -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. 
However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup @@ -55,20 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index ad57645002f8..e640ce111dff 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -69,10 +69,10 @@ def test(net, testloader): return loss, accuracy -def load_data(node_id): +def load_data(partition_id): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( @@ -94,19 +94,20 @@ def apply_transforms(batch): # 2. 
Federation of the pipeline with Flower # ############################################################################# -# Get node id +# Get partition id parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], + required=True, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) -node_id = parser.parse_args().node_id +partition_id = parser.parse_args().partition_id # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data(node_id=node_id) +trainloader, testloader = load_data(partition_id=partition_id) # Define Flower client @@ -131,7 +132,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index ec6a3af8c5b4..d8e1503dd8a7 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-pytorch/run.sh b/examples/quickstart-pytorch/run.sh index cdace99bb8df..6ca9c8cafec9 100755 --- a/examples/quickstart-pytorch/run.sh +++ b/examples/quickstart-pytorch/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "$i" & + python client.py --partition-id "$i" & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-sklearn-tabular/README.md 
b/examples/quickstart-sklearn-tabular/README.md new file mode 100644 index 000000000000..a975a9392800 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/README.md @@ -0,0 +1,77 @@ +# Flower Example using scikit-learn + +This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on +"iris" (tabular) dataset. +It will help you understand how to adapt Flower for use with `scikit-learn`. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to +download, partition and preprocess the dataset. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-sklearn-tabular . && rm -rf flower && cd quickstart-sklearn-tabular +``` + +This will create a new directory called `quickstart-sklearn-tabular` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- utils.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. 
To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with scikit-learn and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +poetry run python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: + +```shell +poetry run python3 client.py --partition-id 0 # partition-id should be any of {0,1,2} +``` + +Alternatively you can run all of it in one shell as follows: + +```shell +poetry run python3 server.py & +poetry run python3 client.py --partition-id 0 & +poetry run python3 client.py --partition-id 1 +``` + +You will see that Flower is starting a federated training. 
diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py new file mode 100644 index 000000000000..b7e3046c822d --- /dev/null +++ b/examples/quickstart-sklearn-tabular/client.py @@ -0,0 +1,73 @@ +import argparse +import warnings + +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import log_loss + +import flwr as fl +import utils +from flwr_datasets import FederatedDataset + +if __name__ == "__main__": + N_CLIENTS = 3 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.partition_id + + # Load the partition data + fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) + + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + X = dataset[["petal_length", "petal_width", "sepal_length", "sepal_width"]] + y = dataset["species"] + unique_labels = fds.load_split("train").unique("species") + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] + + # Create LogisticRegression Model + model = LogisticRegression( + penalty="l2", + max_iter=1, # local epoch + warm_start=True, # prevent refreshing weights when fitting + ) + + # Setting initial parameters, akin to model.compile for keras models + utils.set_initial_params(model, n_features=X_train.shape[1], n_classes=3) + + # Define Flower client + class IrisClient(fl.client.NumPyClient): + def get_parameters(self, config): # type: ignore + return utils.get_model_parameters(model) + + def fit(self, parameters, config): # type: ignore + utils.set_model_params(model, parameters) + # Ignore convergence failure due to low local epochs + with warnings.catch_warnings(): + 
warnings.simplefilter("ignore") + model.fit(X_train, y_train) + accuracy = model.score(X_train, y_train) + return ( + utils.get_model_parameters(model), + len(X_train), + {"train_accuracy": accuracy}, + ) + + def evaluate(self, parameters, config): # type: ignore + utils.set_model_params(model, parameters) + loss = log_loss(y_test, model.predict_proba(X_test), labels=unique_labels) + accuracy = model.score(X_test, y_test) + return loss, len(X_test), {"test_accuracy": accuracy} + + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:8080", client=IrisClient().to_client() + ) diff --git a/examples/quickstart-sklearn-tabular/pyproject.toml b/examples/quickstart-sklearn-tabular/pyproject.toml new file mode 100644 index 000000000000..86eab5c38df0 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "sklearn-mnist" +version = "0.1.0" +description = "Federated learning with scikit-learn and Flower" +authors = [ + "The Flower Authors ", + "Kaushik Amar Das ", +] + +[tool.poetry.dependencies] +python = "^3.8" +flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +scikit-learn = "^1.3.0" diff --git a/examples/quickstart-sklearn-tabular/requirements.txt b/examples/quickstart-sklearn-tabular/requirements.txt new file mode 100644 index 000000000000..e0f15b31f3f7 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/requirements.txt @@ -0,0 +1,3 @@ +flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +scikit-learn>=1.3.0 diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh new file mode 100755 index 000000000000..f770ca05f8f4 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting 
server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in $(seq 0 1); do + echo "Starting client $i" + python client.py --partition-id "${i}" & +done + +# This will allow you to use CTRL+C to stop all background processes +trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/quickstart-sklearn-tabular/server.py b/examples/quickstart-sklearn-tabular/server.py new file mode 100644 index 000000000000..0c779c52a8d6 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/server.py @@ -0,0 +1,19 @@ +import flwr as fl +import utils +from sklearn.linear_model import LogisticRegression + + +# Start Flower server for five rounds of federated learning +if __name__ == "__main__": + model = LogisticRegression() + utils.set_initial_params(model, n_classes=3, n_features=4) + strategy = fl.server.strategy.FedAvg( + min_available_clients=2, + fit_metrics_aggregation_fn=utils.weighted_average, + evaluate_metrics_aggregation_fn=utils.weighted_average, + ) + fl.server.start_server( + server_address="0.0.0.0:8080", + strategy=strategy, + config=fl.server.ServerConfig(num_rounds=25), + ) diff --git a/examples/quickstart-sklearn-tabular/utils.py b/examples/quickstart-sklearn-tabular/utils.py new file mode 100644 index 000000000000..e154f44ef8bf --- /dev/null +++ b/examples/quickstart-sklearn-tabular/utils.py @@ -0,0 +1,75 @@ +from typing import List, Tuple, Dict + +import numpy as np +from sklearn.linear_model import LogisticRegression + +from flwr.common import NDArrays, Metrics, Scalar + + +def get_model_parameters(model: LogisticRegression) -> NDArrays: + """Return the parameters of a sklearn LogisticRegression model.""" + if model.fit_intercept: + params = [ + model.coef_, + model.intercept_, + ] + else: + params = [ + model.coef_, + ] + return params + + +def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression: + """Set the 
parameters of a sklearn LogisticRegression model.""" + model.coef_ = params[0] + if model.fit_intercept: + model.intercept_ = params[1] + return model + + +def set_initial_params(model: LogisticRegression, n_classes: int, n_features: int): + """Set initial parameters as zeros. + + Required since model params are uninitialized until model.fit is called but server + asks for initial parameters from clients at launch. Refer to + sklearn.linear_model.LogisticRegression documentation for more information. + """ + model.classes_ = np.array([i for i in range(n_classes)]) + + model.coef_ = np.zeros((n_classes, n_features)) + if model.fit_intercept: + model.intercept_ = np.zeros((n_classes,)) + + +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]: + """Compute weighted average. + + It is a generic implementation that averages only over floats and ints and drops the + other data types of the Metrics. + """ + print(metrics) + # num_samples_list can represent number of samples or batches depending on the client + num_samples_list = [n_batches for n_batches, _ in metrics] + num_samples_sum = sum(num_samples_list) + metrics_lists: Dict[str, List[float]] = {} + for num_samples, all_metrics_dict in metrics: + # Calculate each metric one by one + for single_metric, value in all_metrics_dict.items(): + if isinstance(value, (float, int)): + metrics_lists[single_metric] = [] + # Just one iteration needed to initialize the keywords + break + + for num_samples, all_metrics_dict in metrics: + # Calculate each metric one by one + for single_metric, value in all_metrics_dict.items(): + # Add weighted metric + if isinstance(value, (float, int)): + metrics_lists[single_metric].append(float(num_samples * value)) + + weighted_metrics: Dict[str, Scalar] = {} + for metric_name, metric_values in metrics_lists.items(): + weighted_metrics[metric_name] = sum(metric_values) / num_samples_sum + + return weighted_metrics diff --git a/examples/quickstart-tabnet/client.py 
b/examples/quickstart-tabnet/client.py index da391a95710a..2289b1b55b3d 100644 --- a/examples/quickstart-tabnet/client.py +++ b/examples/quickstart-tabnet/client.py @@ -79,4 +79,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=TabNetClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=TabNetClient().to_client() +) diff --git a/examples/quickstart-tabnet/pyproject.toml b/examples/quickstart-tabnet/pyproject.toml index 948eaf085b86..18f1979791bd 100644 --- a/examples/quickstart-tabnet/pyproject.toml +++ b/examples/quickstart-tabnet/pyproject.toml @@ -6,12 +6,12 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.8.3" tabnet = "0.1.6" diff --git a/examples/quickstart-tensorflow/README.md b/examples/quickstart-tensorflow/README.md index 7ada48797d03..8d5e9434b086 100644 --- a/examples/quickstart-tensorflow/README.md +++ b/examples/quickstart-tensorflow/README.md @@ -1,7 +1,7 @@ # Flower Example using TensorFlow/Keras -This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understanding how to adapt Flower to your use-cases. 
-Running this example in itself is quite easy. +This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup @@ -50,7 +50,7 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: +Afterward, you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: ```shell poetry run python3 server.py @@ -62,7 +62,7 @@ Now you are ready to start the Flower clients which will participate in the lear poetry run python3 client.py ``` -Alternatively you can run all of it in one shell as follows: +Alternatively, you can run all of it in one shell as follows: ```shell poetry run python3 server.py & diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index fc367e2c3053..3e2035c09311 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -1,16 +1,38 @@ +import argparse import os import flwr as fl import tensorflow as tf - +from flwr_datasets import FederatedDataset # Make TensorFlow log less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +# Parse arguments +parser = argparse.ArgumentParser(description="Flower") +parser.add_argument( + "--partition-id", + type=int, + choices=[0, 1, 2], + required=True, + help="Partition of the dataset (0,1 or 2). 
" + "The dataset is divided into 3 partitions created artificially.", +) +args = parser.parse_args() + # Load model and data (MobileNetV2, CIFAR-10) model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) -(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() + +# Download and partition dataset +fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) +partition = fds.load_partition(args.partition_id, "train") +partition.set_format("numpy") + +# Divide data on each node: 80% train, 20% test +partition = partition.train_test_split(test_size=0.2) +x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] +x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] # Define Flower client @@ -30,4 +52,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() +) diff --git a/examples/quickstart-tensorflow/pyproject.toml b/examples/quickstart-tensorflow/pyproject.toml index 68d4f9aada52..98aeb932cab9 100644 --- a/examples/quickstart-tensorflow/pyproject.toml +++ b/examples/quickstart-tensorflow/pyproject.toml @@ -6,10 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +flwr-datasets = { extras = ["vision"], version = 
">=0.0.2,<1.0.0" } +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/quickstart-tensorflow/requirements.txt b/examples/quickstart-tensorflow/requirements.txt index 6420aab25ec8..7f025975cae9 100644 --- a/examples/quickstart-tensorflow/requirements.txt +++ b/examples/quickstart-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" diff --git a/examples/quickstart-tensorflow/run.sh b/examples/quickstart-tensorflow/run.sh index c64f362086aa..76188f197e3e 100755 --- a/examples/quickstart-tensorflow/run.sh +++ b/examples/quickstart-tensorflow/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-tensorflow/server.py b/examples/quickstart-tensorflow/server.py index 39c350388c1b..fe691a88aba0 100644 --- a/examples/quickstart-tensorflow/server.py +++ b/examples/quickstart-tensorflow/server.py @@ -1,8 +1,25 @@ +from typing import List, Tuple + import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + +# 
Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server fl.server.start_server( server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, ) diff --git a/examples/quickstart-xgboost-horizontal/.gitignore b/examples/quickstart-xgboost-horizontal/.gitignore deleted file mode 100644 index 4a6ddf5b9142..000000000000 --- a/examples/quickstart-xgboost-horizontal/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -dataset - diff --git a/examples/quickstart-xgboost-horizontal/README.md b/examples/quickstart-xgboost-horizontal/README.md deleted file mode 100644 index 346a33da7412..000000000000 --- a/examples/quickstart-xgboost-horizontal/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Federated XGBoost in Horizontal Setting (PyTorch) - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb)) - -This example demonstrates a federated XGBoost using Flower with PyTorch. This is a novel method to conduct federated XGBoost in the horizontal setting. It differs from the previous methods in the following ways: - -- We aggregate and conduct federated learning on client tree’s prediction outcomes by sending clients' built XGBoost trees to the server and then sharing to the clients. -- The exchange of privacy-sensitive information (gradients) is not needed. -- The model is a CNN with 1D convolution kernel size = the number of XGBoost trees in the client tree ensembles. -- Using 1D convolution, we make the tree learning rate (a hyperparameter of XGBoost) learnable. 
- -## Project Setup - -This implementation can be easily run in Google Colab with the button at the top of the README or as a standalone Jupyter notebook, -it will automatically download and extract the example data inside a `dataset` folder and `binary_classification` and `regression` sub-folders. - -## Datasets - -This implementation supports both binary classification and regression datasets in SVM light format, loaded from ([LIBSVM Data](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/)). Simply download the dataset files from the website and put them in the folder location indicated above. diff --git a/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb b/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb deleted file mode 100644 index 4d76e0c26023..000000000000 --- a/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb +++ /dev/null @@ -1,1560 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 15871, - "status": "ok", - "timestamp": 1670356049976, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "2c588ea0-a383-4461-e633-794e73d0f57a" - }, - "outputs": [], - "source": [ - "import os\n", - "import urllib.request\n", - "import bz2\n", - "import shutil\n", - "\n", - "CLASSIFICATION_PATH = os.path.join(\"dataset\", \"binary_classification\")\n", - "REGRESSION_PATH = os.path.join(\"dataset\", \"regression\")\n", - "\n", - "if not os.path.exists(CLASSIFICATION_PATH):\n", - " os.makedirs(CLASSIFICATION_PATH)\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna')}\",\n", - " )\n", - " 
urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna.t\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna.t')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna.r\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna.r')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/ijcnn1.t.bz2\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'ijcnn1.t.bz2')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/ijcnn1.tr.bz2\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'ijcnn1.tr.bz2')}\",\n", - " )\n", - " for filepath in os.listdir(CLASSIFICATION_PATH):\n", - " if filepath[-3:] == \"bz2\":\n", - " abs_filepath = os.path.join(CLASSIFICATION_PATH, filepath)\n", - " with bz2.BZ2File(abs_filepath) as fr, open(abs_filepath[:-4], \"wb\") as fw:\n", - " shutil.copyfileobj(fr, fw)\n", - "\n", - "if not os.path.exists(REGRESSION_PATH):\n", - " os.makedirs(REGRESSION_PATH)\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'eunite2001')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001.t\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'eunite2001.t')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'YearPredictionMSD.bz2')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.t.bz2\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'YearPredictionMSD.t.bz2')}\",\n", - " 
)\n", - " for filepath in os.listdir(REGRESSION_PATH):\n", - " if filepath[-3:] == \"bz2\":\n", - " abs_filepath = os.path.join(REGRESSION_PATH, filepath)\n", - " with bz2.BZ2File(abs_filepath) as fr, open(abs_filepath[:-4], \"wb\") as fw:\n", - " shutil.copyfileobj(fr, fw)\n", - "\n", - "\n", - "!nvidia-smi\n", - "!pip install matplotlib scikit-learn tqdm torch torchmetrics torchsummary xgboost\n", - "!pip install -U \"flwr-nightly[simulation]\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Import relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 7, - "status": "ok", - "timestamp": 1670356049977, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "5289e33e-e18e-491b-d536-6b1052598994" - }, - "outputs": [], - "source": [ - "import xgboost as xgb\n", - "from xgboost import XGBClassifier, XGBRegressor\n", - "from sklearn.metrics import mean_squared_error, accuracy_score\n", - "from sklearn.datasets import load_svmlight_file\n", - "\n", - "import numpy as np\n", - "import torch, torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torchvision\n", - "from torchmetrics import Accuracy, MeanSquaredError\n", - "from tqdm import trange, tqdm\n", - "from torchsummary import summary\n", - "from torch.utils.data import DataLoader, Dataset, random_split\n", - "\n", - "print(\"Imported modules.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Import Flower relevant modules for Federated XGBoost" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import flwr as fl\n", - "from flwr.common.typing import Parameters\n", - "from collections import OrderedDict\n", - "from typing import Any, Dict, List, Optional, Tuple, 
Union\n", - "from flwr.common import NDArray, NDArrays\n", - "\n", - "print(\"Imported modules.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define utility function for xgboost trees" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt # pylint: disable=E0401\n", - "\n", - "\n", - "def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:\n", - " \"\"\"Visualize the built xgboost tree.\"\"\"\n", - " xgb.plot_tree(tree, num_trees=n_tree)\n", - " plt.rcParams[\"figure.figsize\"] = [50, 10]\n", - " plt.show()\n", - "\n", - "\n", - "def construct_tree(\n", - " dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str\n", - ") -> Union[XGBClassifier, XGBRegressor]:\n", - " \"\"\"Construct a xgboost tree form tabular dataset.\"\"\"\n", - " if tree_type == \"BINARY\":\n", - " tree = xgb.XGBClassifier(\n", - " objective=\"binary:logistic\",\n", - " learning_rate=0.1,\n", - " max_depth=8,\n", - " n_estimators=n_estimators,\n", - " subsample=0.8,\n", - " colsample_bylevel=1,\n", - " colsample_bynode=1,\n", - " colsample_bytree=1,\n", - " alpha=5,\n", - " gamma=5,\n", - " num_parallel_tree=1,\n", - " min_child_weight=1,\n", - " )\n", - "\n", - " elif tree_type == \"REG\":\n", - " tree = xgb.XGBRegressor(\n", - " objective=\"reg:squarederror\",\n", - " learning_rate=0.1,\n", - " max_depth=8,\n", - " n_estimators=n_estimators,\n", - " subsample=0.8,\n", - " colsample_bylevel=1,\n", - " colsample_bynode=1,\n", - " colsample_bytree=1,\n", - " alpha=5,\n", - " gamma=5,\n", - " num_parallel_tree=1,\n", - " min_child_weight=1,\n", - " )\n", - "\n", - " tree.fit(dataset, label)\n", - " return tree\n", - "\n", - "\n", - "def construct_tree_from_loader(\n", - " dataset_loader: DataLoader, n_estimators: int, tree_type: str\n", - ") -> Union[XGBClassifier, XGBRegressor]:\n", - " \"\"\"Construct a 
xgboost tree form tabular dataset loader.\"\"\"\n", - " for dataset in dataset_loader:\n", - " data, label = dataset[0], dataset[1]\n", - " return construct_tree(data, label, n_estimators, tree_type)\n", - "\n", - "\n", - "def single_tree_prediction(\n", - " tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray\n", - ") -> Optional[NDArray]:\n", - " \"\"\"Extract the prediction result of a single tree in the xgboost tree\n", - " ensemble.\"\"\"\n", - " # How to access a single tree\n", - " # https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb\n", - " num_t = len(tree.get_booster().get_dump())\n", - " if n_tree > num_t:\n", - " print(\n", - " \"The tree index to be extracted is larger than the total number of trees.\"\n", - " )\n", - " return None\n", - "\n", - " return tree.predict( # type: ignore\n", - " dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True\n", - " )\n", - "\n", - "\n", - "def tree_encoding( # pylint: disable=R0914\n", - " trainloader: DataLoader,\n", - " client_trees: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> Optional[Tuple[NDArray, NDArray]]:\n", - " \"\"\"Transform the tabular dataset into prediction results using the\n", - " aggregated xgboost tree ensembles from all clients.\"\"\"\n", - " if trainloader is None:\n", - " return None\n", - "\n", - " for local_dataset in trainloader:\n", - " x_train, y_train = local_dataset[0], local_dataset[1]\n", - "\n", - " x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))\n", - " x_train_enc = np.array(x_train_enc, copy=True)\n", - "\n", - " temp_trees: Any = None\n", - " if isinstance(client_trees, list) is False:\n", - " temp_trees = [client_trees[0]] * client_num\n", - " elif isinstance(client_trees, list) and len(client_trees) != 
client_num:\n", - " temp_trees = [client_trees[0][0]] * client_num\n", - " else:\n", - " cids = []\n", - " temp_trees = []\n", - " for i, _ in enumerate(client_trees):\n", - " temp_trees.append(client_trees[i][0]) # type: ignore\n", - " cids.append(client_trees[i][1]) # type: ignore\n", - " sorted_index = np.argsort(np.asarray(cids))\n", - " temp_trees = np.asarray(temp_trees)[sorted_index]\n", - "\n", - " for i, _ in enumerate(temp_trees):\n", - " for j in range(client_tree_num):\n", - " x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(\n", - " temp_trees[i], j, x_train\n", - " )\n", - "\n", - " x_train_enc32: Any = np.float32(x_train_enc)\n", - " y_train32: Any = np.float32(y_train)\n", - "\n", - " x_train_enc32, y_train32 = torch.from_numpy(\n", - " np.expand_dims(x_train_enc32, axis=1) # type: ignore\n", - " ), torch.from_numpy(\n", - " np.expand_dims(y_train32, axis=-1) # type: ignore\n", - " )\n", - " return x_train_enc32, y_train32" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manually download and load the tabular dataset from LIBSVM data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 26613, - "status": "ok", - "timestamp": 1670356076585, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "22843504-faf0-44cf-aedd-1df8d0ec87a6" - }, - "outputs": [], - "source": [ - "# Datasets can be downloaded from LIBSVM Data: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/\n", - "binary_train = [\"cod-rna.t\", \"cod-rna\", \"ijcnn1.t\"]\n", - "binary_test = [\"cod-rna.r\", \"cod-rna.t\", \"ijcnn1.tr\"]\n", - "reg_train = [\"eunite2001\", \"YearPredictionMSD\"]\n", - "reg_test = [\"eunite2001.t\", \"YearPredictionMSD.t\"]\n", - "\n", - "# Define the type of training task. 
Binary classification: BINARY; Regression: REG\n", - "task_types = [\"BINARY\", \"REG\"]\n", - "task_type = task_types[0]\n", - "\n", - "# Select the downloaded training and test dataset\n", - "if task_type == \"BINARY\":\n", - " dataset_path = \"dataset/binary_classification/\"\n", - " train = binary_train[0]\n", - " test = binary_test[0]\n", - "elif task_type == \"REG\":\n", - " dataset_path = \"dataset/regression/\"\n", - " train = reg_train[0]\n", - " test = reg_test[0]\n", - "\n", - "data_train = load_svmlight_file(dataset_path + train, zero_based=False)\n", - "data_test = load_svmlight_file(dataset_path + test, zero_based=False)\n", - "\n", - "print(\"Task type selected is: \" + task_type)\n", - "print(\"Training dataset is: \" + train)\n", - "print(\"Test dataset is: \" + test)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess the tabular dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class TreeDataset(Dataset):\n", - " def __init__(self, data: NDArray, labels: NDArray) -> None:\n", - " self.labels = labels\n", - " self.data = data\n", - "\n", - " def __len__(self) -> int:\n", - " return len(self.labels)\n", - "\n", - " def __getitem__(self, idx: int) -> Dict[int, NDArray]:\n", - " label = self.labels[idx]\n", - " data = self.data[idx, :]\n", - " sample = {0: data, 1: label}\n", - " return sample\n", - "\n", - "\n", - "X_train = data_train[0].toarray()\n", - "y_train = data_train[1]\n", - "X_test = data_test[0].toarray()\n", - "y_test = data_test[1]\n", - "X_train.flags.writeable = True\n", - "y_train.flags.writeable = True\n", - "X_test.flags.writeable = True\n", - "y_test.flags.writeable = True\n", - "\n", - "# If the feature dimensions of the trainset and testset do not agree,\n", - "# specify n_features in the load_svmlight_file function in the above cell.\n", - "# 
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_svmlight_file.html\n", - "print(\"Feature dimension of the dataset:\", X_train.shape[1])\n", - "print(\"Size of the trainset:\", X_train.shape[0])\n", - "print(\"Size of the testset:\", X_test.shape[0])\n", - "assert X_train.shape[1] == X_test.shape[1]\n", - "\n", - "if task_type == \"BINARY\":\n", - " y_train[y_train == -1] = 0\n", - " y_test[y_test == -1] = 0\n", - "\n", - "trainset = TreeDataset(np.array(X_train, copy=True), np.array(y_train, copy=True))\n", - "testset = TreeDataset(np.array(X_test, copy=True), np.array(y_test, copy=True))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Conduct tabular dataset partition for Federated Learning" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_dataloader(\n", - " dataset: Dataset, partition: str, batch_size: Union[int, str]\n", - ") -> DataLoader:\n", - " if batch_size == \"whole\":\n", - " batch_size = len(dataset)\n", - " return DataLoader(\n", - " dataset, batch_size=batch_size, pin_memory=True, shuffle=(partition == \"train\")\n", - " )\n", - "\n", - "\n", - "# https://github.com/adap/flower\n", - "def do_fl_partitioning(\n", - " trainset: Dataset,\n", - " testset: Dataset,\n", - " pool_size: int,\n", - " batch_size: Union[int, str],\n", - " val_ratio: float = 0.0,\n", - ") -> Tuple[DataLoader, DataLoader, DataLoader]:\n", - " # Split training set into `num_clients` partitions to simulate different local datasets\n", - " partition_size = len(trainset) // pool_size\n", - " lengths = [partition_size] * pool_size\n", - " if sum(lengths) != len(trainset):\n", - " lengths[-1] = len(trainset) - sum(lengths[0:-1])\n", - " datasets = random_split(trainset, lengths, torch.Generator().manual_seed(0))\n", - "\n", - " # Split each partition into train/val and create DataLoader\n", - " trainloaders = []\n", - " valloaders = 
[]\n", - " for ds in datasets:\n", - " len_val = int(len(ds) * val_ratio)\n", - " len_train = len(ds) - len_val\n", - " lengths = [len_train, len_val]\n", - " ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(0))\n", - " trainloaders.append(get_dataloader(ds_train, \"train\", batch_size))\n", - " if len_val != 0:\n", - " valloaders.append(get_dataloader(ds_val, \"val\", batch_size))\n", - " else:\n", - " valloaders = None\n", - " testloader = get_dataloader(testset, \"test\", batch_size)\n", - " return trainloaders, valloaders, testloader" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define global variables for Federated XGBoost Learning" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The number of clients participated in the federated learning\n", - "client_num = 5\n", - "\n", - "# The number of XGBoost trees in the tree ensemble that will be built for each client\n", - "client_tree_num = 500 // client_num" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Build global XGBoost tree for comparison" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 1080216, - "status": "ok", - "timestamp": 1670357156788, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "d56f2821-5cd5-49ff-c5dc-f8d088eed799" - }, - "outputs": [], - "source": [ - "global_tree = construct_tree(X_train, y_train, client_tree_num, task_type)\n", - "preds_train = global_tree.predict(X_train)\n", - "preds_test = global_tree.predict(X_test)\n", - "\n", - "if task_type == \"BINARY\":\n", - " result_train = accuracy_score(y_train, preds_train)\n", - " result_test = accuracy_score(y_test, preds_test)\n", - " print(\"Global XGBoost Training Accuracy: 
%f\" % (result_train))\n", - " print(\"Global XGBoost Testing Accuracy: %f\" % (result_test))\n", - "elif task_type == \"REG\":\n", - " result_train = mean_squared_error(y_train, preds_train)\n", - " result_test = mean_squared_error(y_test, preds_test)\n", - " print(\"Global XGBoost Training MSE: %f\" % (result_train))\n", - " print(\"Global XGBoost Testing MSE: %f\" % (result_test))\n", - "\n", - "print(global_tree)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Simulate local XGBoost trees on clients for comparison" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 242310, - "status": "ok", - "timestamp": 1670357399084, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "0739df9f-84de-4749-8de1-7bd7c6a32ccc" - }, - "outputs": [], - "source": [ - "client_trees_comparison = []\n", - "trainloaders, _, testloader = do_fl_partitioning(\n", - " trainset, testset, pool_size=client_num, batch_size=\"whole\", val_ratio=0.0\n", - ")\n", - "\n", - "for i, trainloader in enumerate(trainloaders):\n", - " for local_dataset in trainloader:\n", - " local_X_train, local_y_train = local_dataset[0], local_dataset[1]\n", - " tree = construct_tree(local_X_train, local_y_train, client_tree_num, task_type)\n", - " client_trees_comparison.append(tree)\n", - "\n", - " preds_train = client_trees_comparison[-1].predict(local_X_train)\n", - " preds_test = client_trees_comparison[-1].predict(X_test)\n", - "\n", - " if task_type == \"BINARY\":\n", - " result_train = accuracy_score(local_y_train, preds_train)\n", - " result_test = accuracy_score(y_test, preds_test)\n", - " print(\"Local Client %d XGBoost Training Accuracy: %f\" % (i, result_train))\n", - " print(\"Local Client %d XGBoost Testing Accuracy: %f\" % (i, result_test))\n", - " elif task_type == 
\"REG\":\n", - " result_train = mean_squared_error(local_y_train, preds_train)\n", - " result_test = mean_squared_error(y_test, preds_test)\n", - " print(\"Local Client %d XGBoost Training MSE: %f\" % (i, result_train))\n", - " print(\"Local Client %d XGBoost Testing MSE: %f\" % (i, result_test))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Centralized Federated XGBoost\n", - "#### Create 1D convolutional neural network on trees prediction results. \n", - "#### 1D kernel size == client_tree_num\n", - "#### Make the learning rate of the tree ensembles learnable." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 38, - "status": "ok", - "timestamp": 1670363021675, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "class CNN(nn.Module):\n", - " def __init__(self, n_channel: int = 64) -> None:\n", - " super(CNN, self).__init__()\n", - " n_out = 1\n", - " self.task_type = task_type\n", - " self.conv1d = nn.Conv1d(\n", - " 1, n_channel, kernel_size=client_tree_num, stride=client_tree_num, padding=0\n", - " )\n", - " self.layer_direct = nn.Linear(n_channel * client_num, n_out)\n", - " self.ReLU = nn.ReLU()\n", - " self.Sigmoid = nn.Sigmoid()\n", - " self.Identity = nn.Identity()\n", - "\n", - " # Add weight initialization\n", - " for layer in self.modules():\n", - " if isinstance(layer, nn.Linear):\n", - " nn.init.kaiming_uniform_(\n", - " layer.weight, mode=\"fan_in\", nonlinearity=\"relu\"\n", - " )\n", - "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " x = self.ReLU(self.conv1d(x))\n", - " x = x.flatten(start_dim=1)\n", - " x = self.ReLU(x)\n", - " if self.task_type == \"BINARY\":\n", - " x = self.Sigmoid(self.layer_direct(x))\n", - " elif self.task_type == \"REG\":\n", - " x = self.Identity(self.layer_direct(x))\n", - " return x\n", - "\n", - 
" def get_weights(self) -> fl.common.NDArrays:\n", - " \"\"\"Get model weights as a list of NumPy ndarrays.\"\"\"\n", - " return [\n", - " np.array(val.cpu().numpy(), copy=True)\n", - " for _, val in self.state_dict().items()\n", - " ]\n", - "\n", - " def set_weights(self, weights: fl.common.NDArrays) -> None:\n", - " \"\"\"Set model weights from a list of NumPy ndarrays.\"\"\"\n", - " layer_dict = {}\n", - " for k, v in zip(self.state_dict().keys(), weights):\n", - " if v.ndim != 0:\n", - " layer_dict[k] = torch.Tensor(np.array(v, copy=True))\n", - " state_dict = OrderedDict(layer_dict)\n", - " self.load_state_dict(state_dict, strict=True)\n", - "\n", - "\n", - "def train(\n", - " task_type: str,\n", - " net: CNN,\n", - " trainloader: DataLoader,\n", - " device: torch.device,\n", - " num_iterations: int,\n", - " log_progress: bool = True,\n", - ") -> Tuple[float, float, int]:\n", - " # Define loss and optimizer\n", - " if task_type == \"BINARY\":\n", - " criterion = nn.BCELoss()\n", - " elif task_type == \"REG\":\n", - " criterion = nn.MSELoss()\n", - " # optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-6)\n", - " optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999))\n", - "\n", - " def cycle(iterable):\n", - " \"\"\"Repeats the contents of the train loader, in case it gets exhausted in 'num_iterations'.\"\"\"\n", - " while True:\n", - " for x in iterable:\n", - " yield x\n", - "\n", - " # Train the network\n", - " net.train()\n", - " total_loss, total_result, n_samples = 0.0, 0.0, 0\n", - " pbar = (\n", - " tqdm(iter(cycle(trainloader)), total=num_iterations, desc=f\"TRAIN\")\n", - " if log_progress\n", - " else iter(cycle(trainloader))\n", - " )\n", - "\n", - " # Unusually, this training is formulated in terms of number of updates/iterations/batches processed\n", - " # by the network. 
This will be helpful later on, when partitioning the data across clients: resulting\n", - " # in differences between dataset sizes and hence inconsistent numbers of updates per 'epoch'.\n", - " for i, data in zip(range(num_iterations), pbar):\n", - " tree_outputs, labels = data[0].to(device), data[1].to(device)\n", - " optimizer.zero_grad()\n", - "\n", - " outputs = net(tree_outputs)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # Collected training loss and accuracy statistics\n", - " total_loss += loss.item()\n", - " n_samples += labels.size(0)\n", - "\n", - " if task_type == \"BINARY\":\n", - " acc = Accuracy(task=\"binary\")(outputs, labels.type(torch.int))\n", - " total_result += acc * labels.size(0)\n", - " elif task_type == \"REG\":\n", - " mse = MeanSquaredError()(outputs, labels.type(torch.int))\n", - " total_result += mse * labels.size(0)\n", - "\n", - " if log_progress:\n", - " if task_type == \"BINARY\":\n", - " pbar.set_postfix(\n", - " {\n", - " \"train_loss\": total_loss / n_samples,\n", - " \"train_acc\": total_result / n_samples,\n", - " }\n", - " )\n", - " elif task_type == \"REG\":\n", - " pbar.set_postfix(\n", - " {\n", - " \"train_loss\": total_loss / n_samples,\n", - " \"train_mse\": total_result / n_samples,\n", - " }\n", - " )\n", - " if log_progress:\n", - " print(\"\\n\")\n", - "\n", - " return total_loss / n_samples, total_result / n_samples, n_samples\n", - "\n", - "\n", - "def test(\n", - " task_type: str,\n", - " net: CNN,\n", - " testloader: DataLoader,\n", - " device: torch.device,\n", - " log_progress: bool = True,\n", - ") -> Tuple[float, float, int]:\n", - " \"\"\"Evaluates the network on test data.\"\"\"\n", - " if task_type == \"BINARY\":\n", - " criterion = nn.BCELoss()\n", - " elif task_type == \"REG\":\n", - " criterion = nn.MSELoss()\n", - "\n", - " total_loss, total_result, n_samples = 0.0, 0.0, 0\n", - " net.eval()\n", - " with torch.no_grad():\n", - " pbar = 
tqdm(testloader, desc=\"TEST\") if log_progress else testloader\n", - " for data in pbar:\n", - " tree_outputs, labels = data[0].to(device), data[1].to(device)\n", - " outputs = net(tree_outputs)\n", - "\n", - " # Collected testing loss and accuracy statistics\n", - " total_loss += criterion(outputs, labels).item()\n", - " n_samples += labels.size(0)\n", - "\n", - " if task_type == \"BINARY\":\n", - " acc = Accuracy(task=\"binary\")(\n", - " outputs.cpu(), labels.type(torch.int).cpu()\n", - " )\n", - " total_result += acc * labels.size(0)\n", - " elif task_type == \"REG\":\n", - " mse = MeanSquaredError()(outputs.cpu(), labels.type(torch.int).cpu())\n", - " total_result += mse * labels.size(0)\n", - "\n", - " if log_progress:\n", - " print(\"\\n\")\n", - "\n", - " return total_loss / n_samples, total_result / n_samples, n_samples" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Flower custom client\n", - "## Import Flower custom client relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Flower client\n", - "from flwr.common import (\n", - " EvaluateIns,\n", - " EvaluateRes,\n", - " FitIns,\n", - " FitRes,\n", - " GetPropertiesIns,\n", - " GetPropertiesRes,\n", - " GetParametersIns,\n", - " GetParametersRes,\n", - " Status,\n", - " Code,\n", - " parameters_to_ndarrays,\n", - " ndarrays_to_parameters,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 36, - "status": "ok", - "timestamp": 1670363021676, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "def tree_encoding_loader(\n", - " dataloader: DataLoader,\n", - " batch_size: int,\n", - " client_trees: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, 
int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> DataLoader:\n", - " encoding = tree_encoding(dataloader, client_trees, client_tree_num, client_num)\n", - " if encoding is None:\n", - " return None\n", - " data, labels = encoding\n", - " tree_dataset = TreeDataset(data, labels)\n", - " return get_dataloader(tree_dataset, \"tree\", batch_size)\n", - "\n", - "\n", - "class FL_Client(fl.client.Client):\n", - " def __init__(\n", - " self,\n", - " task_type: str,\n", - " trainloader: DataLoader,\n", - " valloader: DataLoader,\n", - " client_tree_num: int,\n", - " client_num: int,\n", - " cid: str,\n", - " log_progress: bool = False,\n", - " ):\n", - " \"\"\"\n", - " Creates a client for training `network.Net` on tabular dataset.\n", - " \"\"\"\n", - " self.task_type = task_type\n", - " self.cid = cid\n", - " self.tree = construct_tree_from_loader(trainloader, client_tree_num, task_type)\n", - " self.trainloader_original = trainloader\n", - " self.valloader_original = valloader\n", - " self.trainloader = None\n", - " self.valloader = None\n", - " self.client_tree_num = client_tree_num\n", - " self.client_num = client_num\n", - " self.properties = {\"tensor_type\": \"numpy.ndarray\"}\n", - " self.log_progress = log_progress\n", - "\n", - " # instantiate model\n", - " self.net = CNN()\n", - "\n", - " # determine device\n", - " self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - " def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:\n", - " return GetPropertiesRes(properties=self.properties)\n", - "\n", - " def get_parameters(\n", - " self, ins: GetParametersIns\n", - " ) -> Tuple[\n", - " GetParametersRes, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]\n", - " ]:\n", - " return [\n", - " GetParametersRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=ndarrays_to_parameters(self.net.get_weights()),\n", - " ),\n", - " (self.tree, 
int(self.cid)),\n", - " ]\n", - "\n", - " def set_parameters(\n", - " self,\n", - " parameters: Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " ],\n", - " ) -> Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ]:\n", - " self.net.set_weights(parameters_to_ndarrays(parameters[0]))\n", - " return parameters[1]\n", - "\n", - " def fit(self, fit_params: FitIns) -> FitRes:\n", - " # Process incoming request to train\n", - " num_iterations = fit_params.config[\"num_iterations\"]\n", - " batch_size = fit_params.config[\"batch_size\"]\n", - " aggregated_trees = self.set_parameters(fit_params.parameters)\n", - "\n", - " if type(aggregated_trees) is list:\n", - " print(\"Client \" + self.cid + \": recieved\", len(aggregated_trees), \"trees\")\n", - " else:\n", - " print(\"Client \" + self.cid + \": only had its own tree\")\n", - " self.trainloader = tree_encoding_loader(\n", - " self.trainloader_original,\n", - " batch_size,\n", - " aggregated_trees,\n", - " self.client_tree_num,\n", - " self.client_num,\n", - " )\n", - " self.valloader = tree_encoding_loader(\n", - " self.valloader_original,\n", - " batch_size,\n", - " aggregated_trees,\n", - " self.client_tree_num,\n", - " self.client_num,\n", - " )\n", - "\n", - " # num_iterations = None special behaviour: train(...) 
runs for a single epoch, however many updates it may be\n", - " num_iterations = num_iterations or len(self.trainloader)\n", - "\n", - " # Train the model\n", - " print(f\"Client {self.cid}: training for {num_iterations} iterations/updates\")\n", - " self.net.to(self.device)\n", - " train_loss, train_result, num_examples = train(\n", - " self.task_type,\n", - " self.net,\n", - " self.trainloader,\n", - " device=self.device,\n", - " num_iterations=num_iterations,\n", - " log_progress=self.log_progress,\n", - " )\n", - " print(\n", - " f\"Client {self.cid}: training round complete, {num_examples} examples processed\"\n", - " )\n", - "\n", - " # Return training information: model, number of examples processed and metrics\n", - " if self.task_type == \"BINARY\":\n", - " return FitRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=self.get_parameters(fit_params.config),\n", - " num_examples=num_examples,\n", - " metrics={\"loss\": train_loss, \"accuracy\": train_result},\n", - " )\n", - " elif self.task_type == \"REG\":\n", - " return FitRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=self.get_parameters(fit_params.config),\n", - " num_examples=num_examples,\n", - " metrics={\"loss\": train_loss, \"mse\": train_result},\n", - " )\n", - "\n", - " def evaluate(self, eval_params: EvaluateIns) -> EvaluateRes:\n", - " # Process incoming request to evaluate\n", - " self.set_parameters(eval_params.parameters)\n", - "\n", - " # Evaluate the model\n", - " self.net.to(self.device)\n", - " loss, result, num_examples = test(\n", - " self.task_type,\n", - " self.net,\n", - " self.valloader,\n", - " device=self.device,\n", - " log_progress=self.log_progress,\n", - " )\n", - "\n", - " # Return evaluation information\n", - " if self.task_type == \"BINARY\":\n", - " print(\n", - " f\"Client {self.cid}: evaluation on {num_examples} examples: loss={loss:.4f}, accuracy={result:.4f}\"\n", - " )\n", - " return EvaluateRes(\n", - " status=Status(Code.OK, \"\"),\n", - 
" loss=loss,\n", - " num_examples=num_examples,\n", - " metrics={\"accuracy\": result},\n", - " )\n", - " elif self.task_type == \"REG\":\n", - " print(\n", - " f\"Client {self.cid}: evaluation on {num_examples} examples: loss={loss:.4f}, mse={result:.4f}\"\n", - " )\n", - " return EvaluateRes(\n", - " status=Status(Code.OK, \"\"),\n", - " loss=loss,\n", - " num_examples=num_examples,\n", - " metrics={\"mse\": result},\n", - " )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Flower custom server\n", - "## Import Flower custom server relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Flower server\n", - "import functools\n", - "from flwr.server.strategy import FedXgbNnAvg\n", - "from flwr.server.app import ServerConfig\n", - "\n", - "import timeit\n", - "from logging import DEBUG, INFO\n", - "from typing import Dict, List, Optional, Tuple, Union\n", - "\n", - "from flwr.common import DisconnectRes, Parameters, ReconnectIns, Scalar\n", - "from flwr.common.logger import log\n", - "from flwr.common.typing import GetParametersIns\n", - "from flwr.server.client_manager import ClientManager, SimpleClientManager\n", - "from flwr.server.client_proxy import ClientProxy\n", - "from flwr.server.history import History\n", - "from flwr.server.strategy import Strategy\n", - "from flwr.server.server import (\n", - " reconnect_clients,\n", - " reconnect_client,\n", - " fit_clients,\n", - " fit_client,\n", - " _handle_finished_future_after_fit,\n", - " evaluate_clients,\n", - " evaluate_client,\n", - " _handle_finished_future_after_evaluate,\n", - ")\n", - "\n", - "FitResultsAndFailures = Tuple[\n", - " List[Tuple[ClientProxy, FitRes]],\n", - " List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n", - "]\n", - "EvaluateResultsAndFailures = Tuple[\n", - " List[Tuple[ClientProxy, EvaluateRes]],\n", - " List[Union[Tuple[ClientProxy, 
EvaluateRes], BaseException]],\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FL_Server(fl.server.Server):\n", - " \"\"\"Flower server.\"\"\"\n", - "\n", - " def __init__(\n", - " self, *, client_manager: ClientManager, strategy: Optional[Strategy] = None\n", - " ) -> None:\n", - " self._client_manager: ClientManager = client_manager\n", - " self.parameters: Parameters = Parameters(\n", - " tensors=[], tensor_type=\"numpy.ndarray\"\n", - " )\n", - " self.strategy: Strategy = strategy\n", - " self.max_workers: Optional[int] = None\n", - "\n", - " # pylint: disable=too-many-locals\n", - " def fit(self, num_rounds: int, timeout: Optional[float]) -> History:\n", - " \"\"\"Run federated averaging for a number of rounds.\"\"\"\n", - " history = History()\n", - "\n", - " # Initialize parameters\n", - " log(INFO, \"Initializing global parameters\")\n", - " self.parameters = self._get_initial_parameters(timeout=timeout)\n", - "\n", - " log(INFO, \"Evaluating initial parameters\")\n", - " res = self.strategy.evaluate(0, parameters=self.parameters)\n", - " if res is not None:\n", - " log(\n", - " INFO,\n", - " \"initial parameters (loss, other metrics): %s, %s\",\n", - " res[0],\n", - " res[1],\n", - " )\n", - " history.add_loss_centralized(server_round=0, loss=res[0])\n", - " history.add_metrics_centralized(server_round=0, metrics=res[1])\n", - "\n", - " # Run federated learning for num_rounds\n", - " log(INFO, \"FL starting\")\n", - " start_time = timeit.default_timer()\n", - "\n", - " for current_round in range(1, num_rounds + 1):\n", - " # Train model and replace previous global model\n", - " res_fit = self.fit_round(server_round=current_round, timeout=timeout)\n", - " if res_fit:\n", - " parameters_prime, _, _ = res_fit # fit_metrics_aggregated\n", - " if parameters_prime:\n", - " self.parameters = parameters_prime\n", - "\n", - " # Evaluate model using strategy implementation\n", - " 
res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)\n", - " if res_cen is not None:\n", - " loss_cen, metrics_cen = res_cen\n", - " log(\n", - " INFO,\n", - " \"fit progress: (%s, %s, %s, %s)\",\n", - " current_round,\n", - " loss_cen,\n", - " metrics_cen,\n", - " timeit.default_timer() - start_time,\n", - " )\n", - " history.add_loss_centralized(server_round=current_round, loss=loss_cen)\n", - " history.add_metrics_centralized(\n", - " server_round=current_round, metrics=metrics_cen\n", - " )\n", - "\n", - " # Evaluate model on a sample of available clients\n", - " res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)\n", - " if res_fed:\n", - " loss_fed, evaluate_metrics_fed, _ = res_fed\n", - " if loss_fed:\n", - " history.add_loss_distributed(\n", - " server_round=current_round, loss=loss_fed\n", - " )\n", - " history.add_metrics_distributed(\n", - " server_round=current_round, metrics=evaluate_metrics_fed\n", - " )\n", - "\n", - " # Bookkeeping\n", - " end_time = timeit.default_timer()\n", - " elapsed = end_time - start_time\n", - " log(INFO, \"FL finished in %s\", elapsed)\n", - " return history\n", - "\n", - " def evaluate_round(\n", - " self,\n", - " server_round: int,\n", - " timeout: Optional[float],\n", - " ) -> Optional[\n", - " Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]\n", - " ]:\n", - " \"\"\"Validate current global model on a number of clients.\"\"\"\n", - "\n", - " # Get clients and their respective instructions from strategy\n", - " client_instructions = self.strategy.configure_evaluate(\n", - " server_round=server_round,\n", - " parameters=self.parameters,\n", - " client_manager=self._client_manager,\n", - " )\n", - " if not client_instructions:\n", - " log(INFO, \"evaluate_round %s: no clients selected, cancel\", server_round)\n", - " return None\n", - " log(\n", - " DEBUG,\n", - " \"evaluate_round %s: strategy sampled %s clients (out of %s)\",\n", - " server_round,\n", - " 
len(client_instructions),\n", - " self._client_manager.num_available(),\n", - " )\n", - "\n", - " # Collect `evaluate` results from all clients participating in this round\n", - " results, failures = evaluate_clients(\n", - " client_instructions,\n", - " max_workers=self.max_workers,\n", - " timeout=timeout,\n", - " )\n", - " log(\n", - " DEBUG,\n", - " \"evaluate_round %s received %s results and %s failures\",\n", - " server_round,\n", - " len(results),\n", - " len(failures),\n", - " )\n", - "\n", - " # Aggregate the evaluation results\n", - " aggregated_result: Tuple[\n", - " Optional[float],\n", - " Dict[str, Scalar],\n", - " ] = self.strategy.aggregate_evaluate(server_round, results, failures)\n", - "\n", - " loss_aggregated, metrics_aggregated = aggregated_result\n", - " return loss_aggregated, metrics_aggregated, (results, failures)\n", - "\n", - " def fit_round(\n", - " self,\n", - " server_round: int,\n", - " timeout: Optional[float],\n", - " ) -> Optional[\n", - " Tuple[\n", - " Optional[\n", - " Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[\n", - " Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]\n", - " ],\n", - " ],\n", - " ]\n", - " ],\n", - " Dict[str, Scalar],\n", - " FitResultsAndFailures,\n", - " ]\n", - " ]:\n", - " \"\"\"Perform a single round of federated averaging.\"\"\"\n", - "\n", - " # Get clients and their respective instructions from strategy\n", - " client_instructions = self.strategy.configure_fit(\n", - " server_round=server_round,\n", - " parameters=self.parameters,\n", - " client_manager=self._client_manager,\n", - " )\n", - "\n", - " if not client_instructions:\n", - " log(INFO, \"fit_round %s: no clients selected, cancel\", server_round)\n", - " return None\n", - " log(\n", - " DEBUG,\n", - " \"fit_round %s: strategy sampled %s clients (out of %s)\",\n", - " server_round,\n", - " len(client_instructions),\n", - " 
self._client_manager.num_available(),\n", - " )\n", - "\n", - " # Collect `fit` results from all clients participating in this round\n", - " results, failures = fit_clients(\n", - " client_instructions=client_instructions,\n", - " max_workers=self.max_workers,\n", - " timeout=timeout,\n", - " )\n", - "\n", - " log(\n", - " DEBUG,\n", - " \"fit_round %s received %s results and %s failures\",\n", - " server_round,\n", - " len(results),\n", - " len(failures),\n", - " )\n", - "\n", - " # Aggregate training results\n", - " NN_aggregated: Parameters\n", - " trees_aggregated: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ]\n", - " metrics_aggregated: Dict[str, Scalar]\n", - " aggregated, metrics_aggregated = self.strategy.aggregate_fit(\n", - " server_round, results, failures\n", - " )\n", - " NN_aggregated, trees_aggregated = aggregated[0], aggregated[1]\n", - "\n", - " if type(trees_aggregated) is list:\n", - " print(\"Server side aggregated\", len(trees_aggregated), \"trees.\")\n", - " else:\n", - " print(\"Server side did not aggregate trees.\")\n", - "\n", - " return (\n", - " [NN_aggregated, trees_aggregated],\n", - " metrics_aggregated,\n", - " (results, failures),\n", - " )\n", - "\n", - " def _get_initial_parameters(\n", - " self, timeout: Optional[float]\n", - " ) -> Tuple[Parameters, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]]:\n", - " \"\"\"Get initial parameters from one of the available clients.\"\"\"\n", - "\n", - " # Server-side parameter initialization\n", - " parameters: Optional[Parameters] = self.strategy.initialize_parameters(\n", - " client_manager=self._client_manager\n", - " )\n", - " if parameters is not None:\n", - " log(INFO, \"Using initial parameters provided by strategy\")\n", - " return parameters\n", - "\n", - " # Get initial parameters from one of the clients\n", - " log(INFO, \"Requesting initial parameters from 
one random client\")\n", - " random_client = self._client_manager.sample(1)[0]\n", - " ins = GetParametersIns(config={})\n", - " get_parameters_res_tree = random_client.get_parameters(ins=ins, timeout=timeout)\n", - " parameters = [get_parameters_res_tree[0].parameters, get_parameters_res_tree[1]]\n", - " log(INFO, \"Received initial parameters from one random client\")\n", - "\n", - " return parameters" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create server-side evaluation and experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 35, - "status": "ok", - "timestamp": 1670363021676, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "def print_model_layers(model: nn.Module) -> None:\n", - " print(model)\n", - " for param_tensor in model.state_dict():\n", - " print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n", - "\n", - "\n", - "def serverside_eval(\n", - " server_round: int,\n", - " parameters: Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " ],\n", - " config: Dict[str, Scalar],\n", - " task_type: str,\n", - " testloader: DataLoader,\n", - " batch_size: int,\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> Tuple[float, Dict[str, float]]:\n", - " \"\"\"An evaluation function for centralized/serverside evaluation over the entire test set.\"\"\"\n", - " # device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - " device = \"cpu\"\n", - " model = CNN()\n", - " # print_model_layers(model)\n", - "\n", - " model.set_weights(parameters_to_ndarrays(parameters[0]))\n", - " model.to(device)\n", - "\n", - " trees_aggregated = parameters[1]\n", - 
" testloader = tree_encoding_loader(\n", - " testloader, batch_size, trees_aggregated, client_tree_num, client_num\n", - " )\n", - " loss, result, _ = test(\n", - " task_type, model, testloader, device=device, log_progress=False\n", - " )\n", - "\n", - " if task_type == \"BINARY\":\n", - " print(\n", - " f\"Evaluation on the server: test_loss={loss:.4f}, test_accuracy={result:.4f}\"\n", - " )\n", - " return loss, {\"accuracy\": result}\n", - " elif task_type == \"REG\":\n", - " print(f\"Evaluation on the server: test_loss={loss:.4f}, test_mse={result:.4f}\")\n", - " return loss, {\"mse\": result}\n", - "\n", - "\n", - "def start_experiment(\n", - " task_type: str,\n", - " trainset: Dataset,\n", - " testset: Dataset,\n", - " num_rounds: int = 5,\n", - " client_tree_num: int = 50,\n", - " client_pool_size: int = 5,\n", - " num_iterations: int = 100,\n", - " fraction_fit: float = 1.0,\n", - " min_fit_clients: int = 2,\n", - " batch_size: int = 32,\n", - " val_ratio: float = 0.1,\n", - ") -> History:\n", - " client_resources = {\"num_cpus\": 0.5} # 2 clients per CPU\n", - "\n", - " # Partition the dataset into subsets reserved for each client.\n", - " # - 'val_ratio' controls the proportion of the (local) client reserved as a local test set\n", - " # (good for testing how the final model performs on the client's local unseen data)\n", - " trainloaders, valloaders, testloader = do_fl_partitioning(\n", - " trainset,\n", - " testset,\n", - " batch_size=\"whole\",\n", - " pool_size=client_pool_size,\n", - " val_ratio=val_ratio,\n", - " )\n", - " print(\n", - " f\"Data partitioned across {client_pool_size} clients\"\n", - " f\" and {val_ratio} of local dataset reserved for validation.\"\n", - " )\n", - "\n", - " # Configure the strategy\n", - " def fit_config(server_round: int) -> Dict[str, Scalar]:\n", - " print(f\"Configuring round {server_round}\")\n", - " return {\n", - " \"num_iterations\": num_iterations,\n", - " \"batch_size\": batch_size,\n", - " }\n", - "\n", - " # 
FedXgbNnAvg\n", - " strategy = FedXgbNnAvg(\n", - " fraction_fit=fraction_fit,\n", - " fraction_evaluate=fraction_fit if val_ratio > 0.0 else 0.0,\n", - " min_fit_clients=min_fit_clients,\n", - " min_evaluate_clients=min_fit_clients,\n", - " min_available_clients=client_pool_size, # all clients should be available\n", - " on_fit_config_fn=fit_config,\n", - " on_evaluate_config_fn=(lambda r: {\"batch_size\": batch_size}),\n", - " evaluate_fn=functools.partial(\n", - " serverside_eval,\n", - " task_type=task_type,\n", - " testloader=testloader,\n", - " batch_size=batch_size,\n", - " client_tree_num=client_tree_num,\n", - " client_num=client_num,\n", - " ),\n", - " accept_failures=False,\n", - " )\n", - "\n", - " print(\n", - " f\"FL experiment configured for {num_rounds} rounds with {client_pool_size} client in the pool.\"\n", - " )\n", - " print(\n", - " f\"FL round will proceed with {fraction_fit * 100}% of clients sampled, at least {min_fit_clients}.\"\n", - " )\n", - "\n", - " def client_fn(cid: str) -> fl.client.Client:\n", - " \"\"\"Creates a federated learning client\"\"\"\n", - " if val_ratio > 0.0 and val_ratio <= 1.0:\n", - " return FL_Client(\n", - " task_type,\n", - " trainloaders[int(cid)],\n", - " valloaders[int(cid)],\n", - " client_tree_num,\n", - " client_pool_size,\n", - " cid,\n", - " log_progress=False,\n", - " )\n", - " else:\n", - " return FL_Client(\n", - " task_type,\n", - " trainloaders[int(cid)],\n", - " None,\n", - " client_tree_num,\n", - " client_pool_size,\n", - " cid,\n", - " log_progress=False,\n", - " )\n", - "\n", - " # Start the simulation\n", - " history = fl.simulation.start_simulation(\n", - " client_fn=client_fn,\n", - " server=FL_Server(client_manager=SimpleClientManager(), strategy=strategy),\n", - " num_clients=client_pool_size,\n", - " client_resources=client_resources,\n", - " config=ServerConfig(num_rounds=num_rounds),\n", - " strategy=strategy,\n", - " )\n", - "\n", - " print(history)\n", - "\n", - " return history" - ] - 
}, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Start federated training and inference\n", - "#### High-level workflow: \n", - "#### At round 1, each client first builds their own local XGBoost tree, and sends to the server. The server aggregates all trees and sends to all clients. \n", - "#### After round 1, each client calculates every other client tree’s prediction results, and trains a convolutional neural network with 1D convolution kernel size == the number of XGBoost trees in the tree ensemble. \n", - "#### The sharing of privacy-sensitive information is not needed, and the learning rate (a hyperparameter for XGBoost) is learnable using 1D convolution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 624 - }, - "executionInfo": { - "elapsed": 7610, - "status": "error", - "timestamp": 1670363029252, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "ee2b7146-07ec-4f97-ba44-5b12b35bbeaf" - }, - "outputs": [], - "source": [ - "start_experiment(\n", - " task_type=task_type,\n", - " trainset=trainset,\n", - " testset=testset,\n", - " num_rounds=20,\n", - " client_tree_num=client_tree_num,\n", - " client_pool_size=client_num,\n", - " num_iterations=100,\n", - " batch_size=64,\n", - " fraction_fit=1.0,\n", - " min_fit_clients=1,\n", - " val_ratio=0.0,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "gpuClass": "premium", - "kernelspec": { - "display_name": "FedXGBoost", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/examples/secaggplus-mt/README.md b/examples/secaggplus-mt/README.md deleted file mode 100644 index 0b3b4db3942e..000000000000 --- 
a/examples/secaggplus-mt/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Secure Aggregation with Driver API - -This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. - -## Installing Dependencies - -Project dependencies (such as and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. - -### Poetry - -```shell -poetry install -poetry shell -``` - -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: - -```shell -poetry run python3 -c "import flwr" -``` - -### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt -``` - -If you don't see any errors you're good to go! 
- -## Run with Driver API - -```bash -./run.sh -``` diff --git a/examples/secaggplus-mt/client.py b/examples/secaggplus-mt/client.py deleted file mode 100644 index f0f1348ee378..000000000000 --- a/examples/secaggplus-mt/client.py +++ /dev/null @@ -1,35 +0,0 @@ -import time - -import numpy as np - -import flwr as fl -from flwr.common import Status, FitIns, FitRes, Code -from flwr.common.parameter import ndarrays_to_parameters -from flwr.client.secure_aggregation import SecAggPlusHandler - - -# Define Flower client with the SecAgg+ protocol -class FlowerClient(fl.client.Client, SecAggPlusHandler): - def fit(self, fit_ins: FitIns) -> FitRes: - ret_vec = [np.ones(3)] - ret = FitRes( - status=Status(code=Code.OK, message="Success"), - parameters=ndarrays_to_parameters(ret_vec), - num_examples=1, - metrics={}, - ) - # Force a significant delay for testing purposes - if self._shared_state.sid == 0: - print(f"Client {self._shared_state.sid} dropped for testing purposes.") - time.sleep(4) - return ret - print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...") - return ret - - -# Start Flower client -fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient(), - transport="grpc-rere", -) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py deleted file mode 100644 index d9f795766f6d..000000000000 --- a/examples/secaggplus-mt/driver.py +++ /dev/null @@ -1,206 +0,0 @@ -import random -import time -from typing import Dict, List, Tuple - -import numpy as np -from workflows import get_workflow_factory - -from flwr.common import Metrics, ndarrays_to_parameters -from flwr.driver import GrpcDriver -from flwr.proto import driver_pb2, node_pb2, task_pb2 -from flwr.server import History - - -# Convert instruction/result dict to/from list of TaskIns/TaskRes -def task_dict_to_task_ins_list( - task_dict: Dict[int, task_pb2.Task] -) -> List[task_pb2.TaskIns]: - def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> 
task_pb2.Task: - _task.MergeFrom(_merge_task) - return _task - - return [ - task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - workload_id=workload_id, - task=merge( - task, - task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=node_pb2.Node( - node_id=sampled_node_id, - # Must be False for this Secure Aggregation example - anonymous=False, - ), - ), - ), - ) - for sampled_node_id, task in task_dict.items() - ] - - -def task_res_list_to_task_dict( - task_res_list: List[task_pb2.TaskRes], -) -> Dict[int, task_pb2.Task]: - return {task_res.task.producer.node_id: task_res.task for task_res in task_res_list} - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) -# -------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 5 -sleep_time = 0.5 -time_out = 3.9 -num_rounds = 3 -parameters = ndarrays_to_parameters([np.ones(3)]) 
-wf_factory = get_workflow_factory() - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_node_ids: List[int] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. - # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids = [0] * num_client_nodes_per_round - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. 
- while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_node_ids: List[int] = [node.node_id for node in get_nodes_res.nodes] - - if len(all_node_ids) >= num_client_nodes_per_round: - # Sample client nodes - sampled_node_ids = random.sample( - all_node_ids, num_client_nodes_per_round - ) - break - - time.sleep(3) - - # Log sampled node IDs - time.sleep(sleep_time) - - workflow = wf_factory(parameters, sampled_node_ids) - node_messages = None - - while True: - try: - instructions: Dict[int, task_pb2.Task] = workflow.send(node_messages) - next(workflow) - except StopIteration: - break - # Schedule a task for all sampled nodes - task_ins_list: List[task_pb2.TaskIns] = task_dict_to_task_ins_list(instructions) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - start_time = time.time() - task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - if time.time() - start_time >= time_out: - break - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = 
driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages = task_res_list_to_task_dict( - [res for res in all_task_res if res.task.HasField("sa")] - ) - workflow.close() - - # Slow down the start of the next round - time.sleep(sleep_time) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/secaggplus-mt/pyproject.toml b/examples/secaggplus-mt/pyproject.toml deleted file mode 100644 index 94d8defa3316..000000000000 --- a/examples/secaggplus-mt/pyproject.toml +++ /dev/null @@ -1,13 +0,0 @@ -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "secaggplus-mt" -version = "0.1.0" -description = "Secure Aggregation with Driver API" -authors = ["The Flower Authors "] - -[tool.poetry.dependencies] -python = ">=3.8,<3.11" -flwr-nightly = { version = "^1.5.0.dev20230629", extras = ["simulation", "rest"] } diff --git a/examples/secaggplus-mt/requirements.txt b/examples/secaggplus-mt/requirements.txt deleted file mode 100644 index eeed6941afc7..000000000000 --- a/examples/secaggplus-mt/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -flwr-nightly[simulation,rest] diff --git a/examples/secaggplus-mt/run.sh b/examples/secaggplus-mt/run.sh deleted file mode 100755 index 5cc769f6cbd8..000000000000 --- a/examples/secaggplus-mt/run.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Kill any currently running client.py processes -pkill -f 'python client.py' - -# Kill any currently running flower-server processes with 
--grpc-rere option -pkill -f 'flower-server --grpc-rere' - -# Start the flower server -echo "Starting flower server in background..." -flower-server --grpc-rere > /dev/null 2>&1 & -sleep 2 - -# Number of client processes to start -N=5 # Replace with your desired value - -echo "Starting $N clients in background..." - -# Start N client processes -for i in $(seq 1 $N) -do - python client.py > /dev/null 2>&1 & - # python client.py & - sleep 0.1 -done - -echo "Starting driver..." -python driver.py - -echo "Clearing background processes..." - -# Kill any currently running client.py processes -pkill -f 'python client.py' - -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py deleted file mode 100644 index 3117e308a498..000000000000 --- a/examples/secaggplus-mt/workflows.py +++ /dev/null @@ -1,372 +0,0 @@ -import random -from logging import WARNING -from typing import Callable, Dict, Generator, List - -import numpy as np - -from flwr.common import ( - Parameters, - Scalar, - bytes_to_ndarray, - log, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) -from flwr.common.secure_aggregation.crypto.shamir import combine_shares -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, - bytes_to_public_key, - generate_shared_key, -) -from flwr.common.secure_aggregation.ndarrays_arithmetic import ( - factor_extract, - get_parameters_shape, - get_zero_parameters, - parameters_addition, - parameters_mod, - parameters_subtraction, -) -from flwr.common.secure_aggregation.quantization import dequantize, quantize -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_DESTINATION_LIST, - KEY_MASKED_PARAMETERS, - KEY_MOD_RANGE, - KEY_PARAMETERS, - 
KEY_PUBLIC_KEY_1, - KEY_PUBLIC_KEY_2, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SECURE_ID_LIST, - KEY_SHARE_LIST, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, -) -from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.serde import named_values_from_proto, named_values_to_proto -from flwr.common.typing import Value -from flwr.proto.task_pb2 import SecureAggregation, Task - - -LOG_EXPLAIN = True - - -def get_workflow_factory() -> ( - Callable[[Parameters, List[int]], Generator[Dict[int, Task], Dict[int, Task], None]] -): - return _wrap_workflow_with_sec_agg - - -def _wrap_in_task(named_values: Dict[str, Value]) -> Task: - return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values))) - - -def _get_from_task(task: Task) -> Dict[str, Value]: - return named_values_from_proto(task.sa.named_values) - - -_secure_aggregation_configuration = { - KEY_SHARE_NUMBER: 3, - KEY_THRESHOLD: 2, - KEY_CLIPPING_RANGE: 3.0, - KEY_TARGET_RANGE: 1 << 20, - KEY_MOD_RANGE: 1 << 30, -} - - -def workflow_with_sec_agg( - parameters: Parameters, - sampled_node_ids: List[int], - sec_agg_config: Dict[str, Scalar], -) -> Generator[Dict[int, Task], Dict[int, Task], None]: - """ - =============== Setup stage =============== - """ - # Protocol config - num_samples = len(sampled_node_ids) - num_shares = sec_agg_config[KEY_SHARE_NUMBER] - threshold = sec_agg_config[KEY_THRESHOLD] - mod_range = sec_agg_config[KEY_MOD_RANGE] - # Quantization config - clipping_range = sec_agg_config[KEY_CLIPPING_RANGE] - target_range = sec_agg_config[KEY_TARGET_RANGE] - - if LOG_EXPLAIN: - _quantized = quantize( - [np.ones(3) for _ in range(num_samples)], clipping_range, target_range - ) - print( - "\n\n################################ Introduction ################################\n" - "In the example, each client will upload a vector 
[1.0, 1.0, 1.0] instead of\n" - "model updates for demonstration purposes.\n" - "Client 0 is configured to drop out before uploading the masked vector.\n" - f"After quantization, the raw vectors will be:" - ) - for i in range(1, num_samples): - print(f"\t{_quantized[i]} from Client {i}") - print( - f"Numbers are rounded to integers stochastically during the quantization\n" - ", and thus not all entries are identical." - ) - print( - "The above raw vectors are hidden from the driver through adding masks.\n" - ) - print( - "########################## Secure Aggregation Start ##########################" - ) - cfg = { - KEY_STAGE: STAGE_SETUP, - KEY_SAMPLE_NUMBER: num_samples, - KEY_SHARE_NUMBER: num_shares, - KEY_THRESHOLD: threshold, - KEY_CLIPPING_RANGE: clipping_range, - KEY_TARGET_RANGE: target_range, - KEY_MOD_RANGE: mod_range, - } - # The number of shares should better be odd in the SecAgg+ protocol. - if num_samples != num_shares and num_shares & 0x1 == 0: - log(WARNING, "Number of shares in the SecAgg+ protocol should be odd.") - num_shares += 1 - - # Randomly assign secure IDs to clients - sids = [i for i in range(len(sampled_node_ids))] - random.shuffle(sids) - nid2sid = dict(zip(sampled_node_ids, sids)) - sid2nid = {sid: nid for nid, sid in nid2sid.items()} - # Build neighbour relations (node ID -> secure IDs of neighbours) - half_share = num_shares >> 1 - nid2neighbours = { - node_id: { - (nid2sid[node_id] + offset) % num_samples - for offset in range(-half_share, half_share + 1) - } - for node_id in sampled_node_ids - } - - surviving_node_ids = sampled_node_ids - if LOG_EXPLAIN: - print( - f"Sending configurations to {num_samples} clients and allocating secure IDs..." 
- ) - # Send setup configuration to clients - yield { - node_id: _wrap_in_task( - named_values={ - **cfg, - KEY_SECURE_ID: nid2sid[node_id], - } - ) - for node_id in surviving_node_ids - } - # Receive public keys from clients and build the dict - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - - if LOG_EXPLAIN: - print(f"Received public keys from {len(surviving_node_ids)} clients.") - - sid2public_keys = {} - for node_id, task in node_messages.items(): - key_dict = _get_from_task(task) - pk1, pk2 = key_dict[KEY_PUBLIC_KEY_1], key_dict[KEY_PUBLIC_KEY_2] - sid2public_keys[nid2sid[node_id]] = [pk1, pk2] - - """ - =============== Share keys stage =============== - """ - if LOG_EXPLAIN: - print(f"\nForwarding public keys...") - # Broadcast public keys to clients - yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_SHARE_KEYS, - **{ - str(sid): value - for sid, value in sid2public_keys.items() - if sid in nid2neighbours[node_id] - }, - } - ) - for node_id in surviving_node_ids - } - - # Receive secret key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received encrypted key shares from {len(surviving_node_ids)} clients.") - # Build forward packet list dictionary - srcs, dsts, ciphertexts = [], [], [] - fwd_ciphertexts: Dict[int, List[bytes]] = { - nid2sid[nid]: [] for nid in surviving_node_ids - } # dest secure ID -> list of ciphertexts - fwd_srcs: Dict[int, List[bytes]] = { - sid: [] for sid in fwd_ciphertexts - } # dest secure ID -> list of src secure IDs - for node_id, task in node_messages.items(): - res_dict = _get_from_task(task) - srcs += [nid2sid[node_id]] * len(res_dict[KEY_DESTINATION_LIST]) - dsts += res_dict[KEY_DESTINATION_LIST] - ciphertexts += res_dict[KEY_CIPHERTEXT_LIST] - - for src, dst, ciphertext in zip(srcs, dsts, ciphertexts): - if dst in fwd_ciphertexts: - fwd_ciphertexts[dst].append(ciphertext) - 
fwd_srcs[dst].append(src) - - """ - =============== Collect masked input stage =============== - """ - - if LOG_EXPLAIN: - print(f"\nForwarding encrypted key shares and requesting masked input...") - # Send encrypted secret key shares to clients (plus model parameters) - weights = parameters_to_ndarrays(parameters) - yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_COLLECT_MASKED_INPUT, - KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]], - KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]], - KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights], - } - ) - for node_id in surviving_node_ids - } - # Collect masked input from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - # Get shape of vector sent by first client - masked_vector = [np.array([0], dtype=int)] + get_zero_parameters( - [w.shape for w in weights] - ) - # Add all collected masked vectors and compuute available and dropout clients set - dead_sids = { - nid2sid[node_id] - for node_id in sampled_node_ids - if node_id not in surviving_node_ids - } - active_sids = {nid2sid[node_id] for node_id in surviving_node_ids} - if LOG_EXPLAIN: - for sid in dead_sids: - print(f"Client {sid} dropped out.") - for node_id, task in node_messages.items(): - named_values = _get_from_task(task) - client_masked_vec = named_values[KEY_MASKED_PARAMETERS] - client_masked_vec = [bytes_to_ndarray(b) for b in client_masked_vec] - if LOG_EXPLAIN: - print(f"Received {client_masked_vec[1]} from Client {nid2sid[node_id]}.") - masked_vector = parameters_addition(masked_vector, client_masked_vec) - masked_vector = parameters_mod(masked_vector, mod_range) - """ - =============== Unmask stage =============== - """ - - if LOG_EXPLAIN: - print("\nRequesting key shares to unmask the aggregate vector...") - # Send secure IDs of active and dead clients. 
- yield { - node_id: _wrap_in_task( - named_values={ - KEY_STAGE: STAGE_UNMASK, - KEY_DEAD_SECURE_ID_LIST: list(dead_sids & nid2neighbours[node_id]), - KEY_ACTIVE_SECURE_ID_LIST: list(active_sids & nid2neighbours[node_id]), - } - ) - for node_id in surviving_node_ids - } - # Collect key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received key shares from {len(surviving_node_ids)} clients.") - # Build collected shares dict - collected_shares_dict: Dict[int, List[bytes]] = {} - for nid in sampled_node_ids: - collected_shares_dict[nid2sid[nid]] = [] - - if len(surviving_node_ids) < threshold: - raise Exception("Not enough available clients after unmask vectors stage") - for _, task in node_messages.items(): - named_values = _get_from_task(task) - for owner_sid, share in zip( - named_values[KEY_SECURE_ID_LIST], named_values[KEY_SHARE_LIST] - ): - collected_shares_dict[owner_sid].append(share) - # Remove mask for every client who is available before ask vectors stage, - # divide vector by first element - active_sids, dead_sids = set(active_sids), set(dead_sids) - for sid, share_list in collected_shares_dict.items(): - if len(share_list) < threshold: - raise Exception( - "Not enough shares to recover secret in unmask vectors stage" - ) - secret = combine_shares(share_list) - if sid in active_sids: - # The seed for PRG is the private mask seed of an active client. - private_mask = pseudo_rand_gen( - secret, mod_range, get_parameters_shape(masked_vector) - ) - masked_vector = parameters_subtraction(masked_vector, private_mask) - else: - # The seed for PRG is the secret key 1 of a dropped client. 
- neighbor_list = list(nid2neighbours[sid2nid[sid]]) - neighbor_list.remove(sid) - - for neighbor_sid in neighbor_list: - shared_key = generate_shared_key( - bytes_to_private_key(secret), - bytes_to_public_key(sid2public_keys[neighbor_sid][0]), - ) - pairwise_mask = pseudo_rand_gen( - shared_key, mod_range, get_parameters_shape(masked_vector) - ) - if sid > neighbor_sid: - masked_vector = parameters_addition(masked_vector, pairwise_mask) - else: - masked_vector = parameters_subtraction(masked_vector, pairwise_mask) - recon_parameters = parameters_mod(masked_vector, mod_range) - # Divide vector by number of clients who have given us their masked vector - # i.e. those participating in final unmask vectors stage - total_weights_factor, recon_parameters = factor_extract(recon_parameters) - if LOG_EXPLAIN: - print(f"Unmasked sum of vectors (quantized): {recon_parameters[0]}") - # recon_parameters = parameters_divide(recon_parameters, total_weights_factor) - aggregated_vector = dequantize( - quantized_parameters=recon_parameters, - clipping_range=clipping_range, - target_range=target_range, - ) - aggregated_vector[0] -= (len(active_sids) - 1) * clipping_range - if LOG_EXPLAIN: - print(f"Unmasked sum of vectors (dequantized): {aggregated_vector[0]}") - print( - f"Aggregate vector using FedAvg: {aggregated_vector[0] / len(active_sids)}" - ) - print( - "########################### Secure Aggregation End ###########################\n\n" - ) - aggregated_parameters = ndarrays_to_parameters(aggregated_vector) - # Update model parameters - parameters.tensors = aggregated_parameters.tensors - parameters.tensor_type = aggregated_parameters.tensor_type - - -def _wrap_workflow_with_sec_agg( - parameters: Parameters, sampled_node_ids: List[int] -) -> Generator[Dict[int, Task], Dict[int, Task], None]: - return workflow_with_sec_agg( - parameters, sampled_node_ids, sec_agg_config=_secure_aggregation_configuration - ) diff --git a/examples/simulation-pytorch/README.md 
b/examples/simulation-pytorch/README.md index 11b7a3364376..93f9e1acbac7 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using PyTorch -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -54,17 +54,13 @@ Write the command below in your terminal to install the dependencies according t pip install -r requirements.txt ``` -### Run Federated Learning Example +### Run with `start_simulation()` -```bash -# You can run the example without activating your environemnt -poetry run python sim.py +Ensure you have activated your environment then: -# Or by first activating it -poetry shell +```bash # and then run the example python sim.py -# you can exit your environment by typing "exit" ``` You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 1xCPU core. For example: @@ -73,10 +69,29 @@ You can adjust the CPU/GPU resources you assign to each of your virtual clients. 
# Will assign 2xCPUs to each client python sim.py --num_cpus=2 -# Will assign 2xCPUs and 20% of the GPU's VRAM to each client -# This means that you can have 5 concurrent clients on each GPU +# Will assign 2xCPUs and 25% of the GPU's VRAM to each client +# This means that you can have 4 concurrent clients on each GPU # (assuming you have enough CPUs) -python sim.py --num_cpus=2 --num_gpus=0.2 +python sim.py --num_cpus=2 --num_gpus=0.25 +``` + +### Run with Flower Next (preview) + +Ensure you have activated your environment, then execute the command below. All `ClientApp` instances will run on CPU but the `ServerApp` will run on the GPU if one is available. Note that this is the case because the `Simulation Engine` only exposes certain resources to the `ClientApp` (based on the `client_resources` in `--backend-config`). + +```bash +# Run with the default backend-config. +# `--server-app` points to the `server` object in the sim.py file in this example. +# `--client-app` points to the `client` object in the sim.py file in this example. +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 +``` + +You can change the default resources assigned to each `ClientApp` by means of the `--backend-config` argument: + +```bash +# Tells the VCE to reserve 2x CPUs and 25% of available VRAM for each ClientApp +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 \ + --backend-config='{"client_resources": {"num_cpus":2, "num_gpus":0.25}}' ``` -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. 
diff --git a/examples/simulation-pytorch/pyproject.toml b/examples/simulation-pytorch/pyproject.toml index 07918c0cd17c..5978c17f2c60 100644 --- a/examples/simulation-pytorch/pyproject.toml +++ b/examples/simulation-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "simulation-pytorch" version = "0.1.0" description = "Federated Learning Simulation with Flower and PyTorch" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -17,4 +17,3 @@ torchvision = "0.16.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.27.0" - diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index 508630cf9422..6dda1ef9319d 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -30,7 +30,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.dev/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." + "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.ai/docs/framework/how-to-run-simulations.html) in Flower. 
With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -178,7 +178,7 @@ "\n", "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." 
] }, { @@ -197,7 +197,7 @@ "# Download MNIST dataset and partition the \"train\" partition (so one can be assigned to each client)\n", "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", "# Let's keep the test set as is, and use it to evaluate the global model on the server\n", - "centralized_testset = mnist_fds.load_full(\"test\")" + "centralized_testset = mnist_fds.load_split(\"test\")" ] }, { @@ -290,7 +290,7 @@ " self.model.to(self.device) # send model to device\n", "\n", " def set_parameters(self, parameters):\n", - " \"\"\"With the model paramters received from the server,\n", + " \"\"\"With the model parameters received from the server,\n", " overwrite the uninitialise model in this class with them.\"\"\"\n", "\n", " params_dict = zip(self.model.state_dict().keys(), parameters)\n", @@ -509,7 +509,7 @@ " valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n", "\n", " # Create and return client\n", - " return FlowerClient(trainloader, valloader)\n", + " return FlowerClient(trainloader, valloader).to_client()\n", "\n", " return client_fn\n", "\n", @@ -605,11 +605,11 @@ "\n", "Get all resources you need!\n", "\n", - "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n", - "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** Our complete documenation: https://flower.ai/docs/\n", + "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n" ] } ], diff --git a/examples/simulation-pytorch/sim.py b/examples/simulation-pytorch/sim.py index 68d9426e83ab..6fb750f2e59c 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -29,9 +29,9 @@ default=0.0, help="Ratio of 
GPU memory to assign to a virtual client", ) -parser.add_argument("--num_rounds", type=int, default=10, help="Number of FL rounds.") NUM_CLIENTS = 100 +NUM_ROUNDS = 10 # Flower client, adapted from Pytorch quickstart example @@ -104,7 +104,7 @@ def client_fn(cid: str) -> fl.client.Client: valset = valset.with_transform(apply_transforms) # Create and return client - return FlowerClient(trainset, valset) + return FlowerClient(trainset, valset).to_client() return client_fn @@ -167,28 +167,36 @@ def evaluate( return evaluate +# Download MNIST dataset and partition it +mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) +centralized_testset = mnist_fds.load_split("test") + +# Configure the strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=0.1, # Sample 10% of available clients for training + fraction_evaluate=0.05, # Sample 5% of available clients for evaluation + min_available_clients=10, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function +) + +# ClientApp for Flower-Next +client = fl.client.ClientApp( + client_fn=get_client_fn(mnist_fds), +) + +# ServerApp for Flower-Next +server = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), + strategy=strategy, +) + + def main(): # Parse input arguments args = parser.parse_args() - # Download MNIST dataset and partition it - mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) - centralized_testset = mnist_fds.load_full("test") - - # Configure the strategy - strategy = fl.server.strategy.FedAvg( - fraction_fit=0.1, # Sample 10% of available clients for training - fraction_evaluate=0.05, # Sample 5% of available clients for evaluation - min_fit_clients=10, # Never sample less than 10 clients for training - min_evaluate_clients=5, # Never sample less than 5 clients for evaluation - 
min_available_clients=int( - NUM_CLIENTS * 0.75 - ), # Wait until at least 75 clients are available - on_fit_config_fn=fit_config, - evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics - evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function - ) - # Resources to be assigned to each virtual client client_resources = { "num_cpus": args.num_cpus, @@ -200,7 +208,7 @@ def main(): client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, client_resources=client_resources, - config=fl.server.ServerConfig(num_rounds=args.num_rounds), + config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), strategy=strategy, actor_kwargs={ "on_actor_init_fn": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients diff --git a/examples/simulation-tensorflow/README.md b/examples/simulation-tensorflow/README.md index f0d94f343d37..917d7b34c7af 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using TensorFlow/Keras -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. 
## Running the example (via Jupyter Notebook) @@ -53,29 +53,46 @@ Write the command below in your terminal to install the dependencies according t pip install -r requirements.txt ``` -### Run Federated Learning Example +### Run with `start_simulation()` -```bash -# You can run the example without activating your environemnt -poetry run python sim.py +Ensure you have activated your environment then: -# Or by first activating it -poetry shell +```bash # and then run the example python sim.py -# you can exit your environment by typing "exit" ``` -You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 1xCPU core. For example: +You can adjust the CPU/GPU resources you assign to each of your virtual clients. By default, your clients will only use 2xCPU core. For example: ```bash # Will assign 2xCPUs to each client python sim.py --num_cpus=2 -# Will assign 2xCPUs and 20% of the GPU's VRAM to each client -# This means that you can have 5 concurrent clients on each GPU +# Will assign 2xCPUs and 25% of the GPU's VRAM to each client +# This means that you can have 4 concurrent clients on each GPU # (assuming you have enough CPUs) -python sim.py --num_cpus=2 --num_gpus=0.2 +python sim.py --num_cpus=2 --num_gpus=0.25 +``` + +Because TensorFlow by default maps all the available VRAM, we need to [enable GPU memory growth](https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth), see how it is done in the example (`sim.py`) for both the "main" process (where the server/strategy runs) and for the clients (using the `actor_kwargs`) + +### Run with Flower Next (preview) + +Ensure you have activated your environment, then execute the command below. All `ClientApp` instances will run on CPU but the `ServerApp` will run on the GPU if one is available. Note that this is the case because the `Simulation Engine` only exposes certain resources to the `ClientApp` (based on the `client_resources` in `--backend-config`). 
For TensorFlow simulations, it is desirable to make use of TF's [memory growth](https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth) feature. You can enable that easily with the `--enable-tf-gpu-growth` flag. + +```bash +# Run with the default backend-config. +# `--server-app` points to the `server` object in the sim.py file in this example. +# `--client-app` points to the `client` object in the sim.py file in this example. +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 --enable-tf-gpu-growth +``` + +You can change the default resources assigned to each `ClientApp` using the `--backend-config` argument. + +```bash +# Tells the VCE to reserve 2x CPUs and 25% of available VRAM for each ClientApp +flower-simulation --client-app=sim:client --server-app=sim:server --num-supernodes=100 \ + --backend-config='{"client_resources": {"num_cpus":2, "num_gpus":0.25}}' --enable-tf-gpu-growth ``` -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. 
diff --git a/examples/simulation-tensorflow/pyproject.toml b/examples/simulation-tensorflow/pyproject.toml index f2e7bd3006c0..ad8cc2032b2d 100644 --- a/examples/simulation-tensorflow/pyproject.toml +++ b/examples/simulation-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "simulation-tensorflow" version = "0.1.0" description = "Federated Learning Simulation with Flower and Tensorflow" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"} -tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"} +tensorflow = { version = "^2.9.1, !=2.11.1", markers = "platform_machine == 'x86_64'" } +tensorflow-macos = { version = "^2.9.1, !=2.11.1", markers = "sys_platform == 'darwin' and platform_machine == 'arm64'" } diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 575b437018f3..797e2dcc603e 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -189,7 +189,7 @@ " )\n", "\n", " # Create and return client\n", - " return FlowerClient(trainset, valset)\n", + " return FlowerClient(trainset, valset).to_client()\n", "\n", " return client_fn\n", "\n", @@ -232,7 +232,7 @@ "\n", "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. 
The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." ] }, { @@ -247,7 +247,7 @@ "# Download MNIST dataset and partition it\n", "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", "# Get the whole test set for centralised evaluation\n", - "centralized_testset = mnist_fds.load_full(\"test\").to_tf_dataset(\n", + "centralized_testset = mnist_fds.load_split(\"test\").to_tf_dataset(\n", " columns=\"image\", label_cols=\"label\", batch_size=64\n", ")\n", "\n", @@ -323,11 +323,11 @@ "\n", "Get all resources you need!\n", "\n", - "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n", - "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** Our complete documenation: https://flower.ai/docs/\n", + "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/" ] } ], diff --git a/examples/simulation-tensorflow/sim.py b/examples/simulation-tensorflow/sim.py index 490e25fe8c8d..e94e5ec96850 100644 --- a/examples/simulation-tensorflow/sim.py +++ b/examples/simulation-tensorflow/sim.py @@ -1,5 +1,4 @@ import os -import math import argparse from typing import Dict, List, Tuple @@ -29,9 +28,9 @@ default=0.0, help="Ratio of GPU memory to assign to a virtual client", ) -parser.add_argument("--num_rounds", 
type=int, default=10, help="Number of FL rounds.") NUM_CLIENTS = 100 +NUM_ROUNDS = 10 VERBOSE = 0 @@ -94,7 +93,7 @@ def client_fn(cid: str) -> fl.client.Client: ) # Create and return client - return FlowerClient(trainset, valset) + return FlowerClient(trainset, valset).to_client() return client_fn @@ -129,30 +128,39 @@ def evaluate( return evaluate +# Download MNIST dataset and partition it +mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) +# Get the whole test set for centralised evaluation +centralized_testset = mnist_fds.load_split("test").to_tf_dataset( + columns="image", label_cols="label", batch_size=64 +) + +# Create FedAvg strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=0.1, # Sample 10% of available clients for training + fraction_evaluate=0.05, # Sample 5% of available clients for evaluation + min_fit_clients=10, # Never sample less than 10 clients for training + evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics + evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function +) + + +# ClientApp for Flower-Next +client = fl.client.ClientApp( + client_fn=get_client_fn(mnist_fds), +) + +# ServerApp for Flower-Next +server = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), + strategy=strategy, +) + + def main() -> None: # Parse input arguments args = parser.parse_args() - # Download MNIST dataset and partition it - mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) - # Get the whole test set for centralised evaluation - centralized_testset = mnist_fds.load_full("test").to_tf_dataset( - columns="image", label_cols="label", batch_size=64 - ) - - # Create FedAvg strategy - strategy = fl.server.strategy.FedAvg( - fraction_fit=0.1, # Sample 10% of available clients for training - fraction_evaluate=0.05, # Sample 5% of available clients for evaluation - min_fit_clients=10, # Never sample less than 10 clients 
for training - min_evaluate_clients=5, # Never sample less than 5 clients for evaluation - min_available_clients=int( - NUM_CLIENTS * 0.75 - ), # Wait until at least 75 clients are available - evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics - evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function - ) - # With a dictionary, you tell Flower's VirtualClientEngine that each # client needs exclusive access to these many resources in order to run client_resources = { @@ -164,7 +172,7 @@ def main() -> None: fl.simulation.start_simulation( client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, - config=fl.server.ServerConfig(num_rounds=args.num_rounds), + config=fl.server.ServerConfig(NUM_ROUNDS), strategy=strategy, client_resources=client_resources, actor_kwargs={ diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index 79ed63a64233..12b1a5e3bc1a 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,7 +1,7 @@ # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -57,18 +57,24 @@ Afterwards you are ready to start the Flower server as well as the clients. You poetry run python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: +Now you are ready to start the Flower clients which will participate in the learning. 
To do so simply open two or more terminals and run the following command in each: + +Start client 1 in the first terminal: ```shell -poetry run python3 client.py +python3 client.py --partition-id 0 # or any integer in {0-9} ``` -Alternatively you can run all of it in one shell as follows: +Start client 2 in the second terminal: ```shell -poetry run python3 server.py & -poetry run python3 client.py & -poetry run python3 client.py +python3 client.py --partition-id 1 # or any integer in {0-9} +``` + +Alternatively, you can run all of it in one shell as follows: + +```bash +bash run.sh ``` You will see that Flower is starting a federated training. diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index dbf0f2f462a7..1e9349df1acc 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -1,19 +1,35 @@ +import argparse import warnings -import flwr as fl -import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss +import flwr as fl import utils +from flwr_datasets import FederatedDataset if __name__ == "__main__": - # Load MNIST dataset from https://www.openml.org/d/554 - (X_train, y_train), (X_test, y_test) = utils.load_mnist() + N_CLIENTS = 10 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.partition_id + + # Load the partition data + fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS}) - # Split train set into 10 partitions and randomly use one for training. 
- partition_id = np.random.choice(10) - (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id] + dataset = fds.load_partition(partition_id, "train").with_format("numpy") + X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] # Create LogisticRegression Model model = LogisticRegression( @@ -46,4 +62,6 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} # Start Flower client - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient()) + fl.client.start_client( + server_address="0.0.0.0:8080", client=MnistClient().to_client() + ) diff --git a/examples/sklearn-logreg-mnist/pyproject.toml b/examples/sklearn-logreg-mnist/pyproject.toml index 7c13b3f3d492..58cc5ca4a02e 100644 --- a/examples/sklearn-logreg-mnist/pyproject.toml +++ b/examples/sklearn-logreg-mnist/pyproject.toml @@ -7,13 +7,13 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = "^3.8" -flwr = "^1.0.0" +flwr = ">=1.0,<2.0" # flwr = { path = "../../", develop = true } # Development +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } scikit-learn = "^1.1.1" -openml = "^0.12.2" diff --git a/examples/sklearn-logreg-mnist/requirements.txt b/examples/sklearn-logreg-mnist/requirements.txt index eec2e1a3c4bd..50da9ace3630 100644 --- a/examples/sklearn-logreg-mnist/requirements.txt +++ b/examples/sklearn-logreg-mnist/requirements.txt @@ -1,4 +1,4 @@ -flwr~=1.4.0 +flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 numpy~=1.21.1 -openml~=0.13.1 scikit_learn~=1.2.2 diff --git a/examples/sklearn-logreg-mnist/run.sh 
b/examples/sklearn-logreg-mnist/run.sh index c64f362086aa..f770ca05f8f4 100755 --- a/examples/sklearn-logreg-mnist/run.sh +++ b/examples/sklearn-logreg-mnist/run.sh @@ -1,15 +1,17 @@ #!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes -trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM # Wait for all background processes to complete wait diff --git a/examples/sklearn-logreg-mnist/server.py b/examples/sklearn-logreg-mnist/server.py index 77e7a89dd668..e0af91fabcee 100644 --- a/examples/sklearn-logreg-mnist/server.py +++ b/examples/sklearn-logreg-mnist/server.py @@ -4,6 +4,8 @@ from sklearn.linear_model import LogisticRegression from typing import Dict +from flwr_datasets import FederatedDataset + def fit_round(server_round: int) -> Dict: """Send round number to client.""" @@ -14,7 +16,9 @@ def get_evaluate_fn(model: LogisticRegression): """Return an evaluation function for server-side evaluation.""" # Load test data here to avoid the overhead of doing it in `evaluate` itself - _, (X_test, y_test) = utils.load_mnist() + fds = FederatedDataset(dataset="mnist", partitioners={"train": 10}) + dataset = fds.load_split("test").with_format("numpy") + X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"] # The `evaluate` function will be called after every round def evaluate(server_round, parameters: fl.common.NDArrays, config): diff --git a/examples/sklearn-logreg-mnist/utils.py b/examples/sklearn-logreg-mnist/utils.py index 6a6d6c12ac73..b279a0d1a4b3 100644 --- a/examples/sklearn-logreg-mnist/utils.py +++ 
b/examples/sklearn-logreg-mnist/utils.py @@ -1,16 +1,11 @@ -from typing import Tuple, Union, List import numpy as np from sklearn.linear_model import LogisticRegression -import openml -XY = Tuple[np.ndarray, np.ndarray] -Dataset = Tuple[XY, XY] -LogRegParams = Union[XY, Tuple[np.ndarray]] -XYList = List[XY] +from flwr.common import NDArrays -def get_model_parameters(model: LogisticRegression) -> LogRegParams: - """Returns the paramters of a sklearn LogisticRegression model.""" +def get_model_parameters(model: LogisticRegression) -> NDArrays: + """Returns the parameters of a sklearn LogisticRegression model.""" if model.fit_intercept: params = [ model.coef_, @@ -23,9 +18,7 @@ def get_model_parameters(model: LogisticRegression) -> LogRegParams: return params -def set_model_params( - model: LogisticRegression, params: LogRegParams -) -> LogisticRegression: +def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression: """Sets the parameters of a sklean LogisticRegression model.""" model.coef_ = params[0] if model.fit_intercept: @@ -47,32 +40,3 @@ def set_initial_params(model: LogisticRegression): model.coef_ = np.zeros((n_classes, n_features)) if model.fit_intercept: model.intercept_ = np.zeros((n_classes,)) - - -def load_mnist() -> Dataset: - """Loads the MNIST dataset using OpenML. 
- - OpenML dataset link: https://www.openml.org/d/554 - """ - mnist_openml = openml.datasets.get_dataset(554) - Xy, _, _, _ = mnist_openml.get_data(dataset_format="array") - X = Xy[:, :-1] # the last column contains labels - y = Xy[:, -1] - # First 60000 samples consist of the train set - x_train, y_train = X[:60000], y[:60000] - x_test, y_test = X[60000:], y[60000:] - return (x_train, y_train), (x_test, y_test) - - -def shuffle(X: np.ndarray, y: np.ndarray) -> XY: - """Shuffle X and y.""" - rng = np.random.default_rng() - idx = rng.permutation(len(X)) - return X[idx], y[idx] - - -def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList: - """Split X and y into a number of partitions.""" - return list( - zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions)) - ) diff --git a/examples/vertical-fl/README.md b/examples/vertical-fl/README.md index d5ab0ab9c30d..78588180d3d6 100644 --- a/examples/vertical-fl/README.md +++ b/examples/vertical-fl/README.md @@ -295,7 +295,7 @@ class ServerModel(nn.Module): It comprises a single linear layer that accepts the concatenated outputs from all client models as its input. The number of inputs to this layer equals the -total number of outputs from the client models ( $3 \times 4 = 12$ ). After processing +total number of outputs from the client models (3 x 4 = 12). After processing these inputs, the linear layer's output is passed through a sigmoid activation function (`nn.Sigmoid()`), which maps the result to a `(0, 1)` range, providing a probability score indicative of the likelihood of survival. 
diff --git a/examples/vertical-fl/pyproject.toml b/examples/vertical-fl/pyproject.toml index 14771c70062f..19dcd0e7a842 100644 --- a/examples/vertical-fl/pyproject.toml +++ b/examples/vertical-fl/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "vertical-fl" version = "0.1.0" description = "PyTorch Vertical FL with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/vit-finetune/README.md b/examples/vit-finetune/README.md new file mode 100644 index 000000000000..ac1652acf02d --- /dev/null +++ b/examples/vit-finetune/README.md @@ -0,0 +1,94 @@ +# Federated finetuning of a ViT + +This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) dataset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images). + +## Running the example + +If you haven't cloned the Flower repository already you might want to clone just this code example and discard the rest. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/vit-finetune .
&& rm -rf flower && cd vit-finetune +``` + +This will create a new directory called `vit-finetune` containing the following files: + +``` +-- README.md <- You're reading this right now +-- main.py <- Main file that launches the simulation +-- client.py <- Contains Flower client code and ClientApp +-- server.py <- Contains Flower server code and ServerApp +-- model.py <- Defines model and train/eval functions +-- dataset.py <- Downloads, partitions and processes dataset +-- pyproject.toml <- Example dependencies, installable using Poetry +-- requirements.txt <- Example dependencies, installable using pip +``` + +### Installing Dependencies + +Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +#### pip + +With an activated environment, install the dependencies for this example: + +```shell +pip install -r requirements.txt +``` + +### Run with `start_simulation()` + +Running the example is quite straightforward. You can control the number of rounds `--num-rounds` (which defaults to 20). + +```bash +python main.py +``` + +![](_static/central_evaluation.png) + +Running the example as-is on an RTX 3090Ti should take ~15s/round running 5 clients in parallel (plus the _global model_ during centralized evaluation stages) in a single GPU. Note that more clients could fit in VRAM, but since the GPU utilization is high (99%-100%) we are probably better off not doing that (at least in this case).
+ +You can adjust the `client_resources` passed to `start_simulation()` so more/less clients run at the same time in the GPU. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. + +```bash ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+======================+======================| +| 0 NVIDIA GeForce RTX 3090 Ti Off | 00000000:0B:00.0 Off | Off | +| 44% 74C P2 441W / 450W | 7266MiB / 24564MiB | 100% Default | +| | | N/A | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| +| 0 N/A N/A 173812 C python 1966MiB | +| 0 N/A N/A 174510 C ray::ClientAppActor.run 1056MiB | +| 0 N/A N/A 174512 C ray::ClientAppActor.run 1056MiB | +| 0 N/A N/A 174513 C ray::ClientAppActor.run 1056MiB | +| 0 N/A N/A 174514 C ray::ClientAppActor.run 1056MiB | +| 0 N/A N/A 174516 C ray::ClientAppActor.run 1056MiB | ++---------------------------------------------------------------------------------------+ +``` + +### Run with Flower Next (preview) + +```bash +flower-simulation \ + --client-app=client:app \ + --server-app=server:app \ + --num-supernodes=20 \ + --backend-config='{"client_resources": {"num_cpus":4, "num_gpus":0.25}}' +``` diff --git 
a/examples/vit-finetune/_static/central_evaluation.png b/examples/vit-finetune/_static/central_evaluation.png new file mode 100644 index 000000000000..d0d53ba353a1 Binary files /dev/null and b/examples/vit-finetune/_static/central_evaluation.png differ diff --git a/examples/vit-finetune/client.py b/examples/vit-finetune/client.py new file mode 100644 index 000000000000..68d98926feeb --- /dev/null +++ b/examples/vit-finetune/client.py @@ -0,0 +1,82 @@ +import torch +from torch.utils.data import DataLoader + +import flwr +from flwr.client import NumPyClient +from dataset import apply_transforms, get_dataset_with_partitions +from model import get_model, set_parameters, train + + +class FedViTClient(NumPyClient): + + def __init__(self, trainset): + + self.trainset = trainset + self.model = get_model() + + # Determine device + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) # send model to device + + def set_for_finetuning(self): + """Freeze all parameters except those in the final head. + + Only output MLP will be updated by the client and therefore, the only part of + the model that will be federated (hence, communicated back to the server for + aggregation.)
+ """ + + # Disable gradients for everything + self.model.requires_grad_(False) + # Now enable just for output head + self.model.heads.requires_grad_(True) + + def get_parameters(self, config): + """Get locally updated parameters.""" + finetune_layers = self.model.heads + return [val.cpu().numpy() for _, val in finetune_layers.state_dict().items()] + + def fit(self, parameters, config): + set_parameters(self.model, parameters) + + # Get some info from the config + # Get batchsize and LR set from server + batch_size = config["batch_size"] + lr = config["lr"] + + trainloader = DataLoader( + self.trainset, batch_size=batch_size, num_workers=2, shuffle=True + ) + + # Set optimizer + optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) + # Train locally + avg_train_loss = train( + self.model, trainloader, optimizer, epochs=1, device=self.device + ) + # Return locally-finetuned part of the model + return ( + self.get_parameters(config={}), + len(trainloader.dataset), + {"train_loss": avg_train_loss}, + ) + + +# Downloads and partition dataset +federated_ox_flowers, _ = get_dataset_with_partitions(num_partitions=20) + + +def client_fn(cid: str): + """Return a FedViTClient that trains with the cid-th data partition.""" + + trainset_for_this_client = federated_ox_flowers.load_partition(int(cid), "train") + + trainset = trainset_for_this_client.with_transform(apply_transforms) + + return FedViTClient(trainset).to_client() + + +# To be used with Flower Next +app = flwr.client.ClientApp( + client_fn=client_fn, +) diff --git a/examples/vit-finetune/dataset.py b/examples/vit-finetune/dataset.py new file mode 100644 index 000000000000..42e0af560a17 --- /dev/null +++ b/examples/vit-finetune/dataset.py @@ -0,0 +1,52 @@ +from torchvision.transforms import ( + Compose, + Normalize, + ToTensor, + RandomResizedCrop, + Resize, + CenterCrop, +) + +from flwr_datasets import FederatedDataset + + +def get_dataset_with_partitions(num_partitions: int): + """Get Oxford Flowers datasets 
and partition it. + + Return partitioned dataset as well as the whole test set. + """ + + # Get Oxford Flowers-102 and divide it into 20 IID partitions + ox_flowers_fds = FederatedDataset( + dataset="nelorth/oxford-flowers", partitioners={"train": num_partitions} + ) + + centralized_testset = ox_flowers_fds.load_split("test") + return ox_flowers_fds, centralized_testset + + +def apply_eval_transforms(batch): + """Apply a very standard set of image transforms.""" + transforms = Compose( + [ + Resize((256, 256)), + CenterCrop((224, 224)), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch + + +def apply_transforms(batch): + """Apply a very standard set of image transforms.""" + transforms = Compose( + [ + RandomResizedCrop((224, 224)), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch diff --git a/examples/vit-finetune/main.py b/examples/vit-finetune/main.py new file mode 100644 index 000000000000..1257246304a1 --- /dev/null +++ b/examples/vit-finetune/main.py @@ -0,0 +1,58 @@ +import argparse + +import flwr as fl +import matplotlib.pyplot as plt + +from server import strategy +from client import client_fn + +parser = argparse.ArgumentParser( + description="Finetuning of a ViT with Flower Simulation." +) + +parser.add_argument( + "--num-rounds", + type=int, + default=20, + help="Number of rounds.", +) + + +def main(): + + args = parser.parse_args() + + # To control the degree of parallelism + # With default settings in this example, + # each client should take just ~1GB of VRAM. 
+ client_resources = { + "num_cpus": 4, + "num_gpus": 0.2, + } + + # Launch simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=20, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + ) + + print(history) + + # Basic plotting + global_accuracy_centralised = history.metrics_centralized["accuracy"] + round = [int(data[0]) for data in global_accuracy_centralised] + acc = [100.0 * data[1] for data in global_accuracy_centralised] + plt.plot(round, acc) + plt.xticks(round) + plt.grid() + plt.ylabel("Accuracy (%)") + plt.xlabel("Round") + plt.title("Federated finetuning of ViT for Flowers-102") + plt.savefig("central_evaluation.png") + + +if __name__ == "__main__": + main() diff --git a/examples/vit-finetune/model.py b/examples/vit-finetune/model.py new file mode 100644 index 000000000000..ca7dc1cd9864 --- /dev/null +++ b/examples/vit-finetune/model.py @@ -0,0 +1,71 @@ +from collections import OrderedDict + +import torch +from torchvision.models import vit_b_16, ViT_B_16_Weights + + +def get_model(): + """Return a pretrained ViT with all layers frozen except output head.""" + + # Instantiate a pre-trained ViT-B on ImageNet + model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1) + + # We're going to federate the finetuning of this model + # using the Oxford Flowers-102 dataset. One easy way to achieve + # this is by re-initializing the output block of the ViT so it + # outputs 102 classes instead of the default 1k + in_features = model.heads[-1].in_features + model.heads[-1] = torch.nn.Linear(in_features, 102) + + # Disable gradients for everything + model.requires_grad_(False) + # Now enable just for output head + model.heads.requires_grad_(True) + + return model + + +def set_parameters(model, parameters): + """Apply the parameters to the model. + + Recall this example only federates the head of the ViT so that's the only part of + the model we need to load.
+ """ + finetune_layers = model.heads + params_dict = zip(finetune_layers.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + finetune_layers.load_state_dict(state_dict, strict=True) + + +def train(net, trainloader, optimizer, epochs, device): + """Train the model on the training set.""" + criterion = torch.nn.CrossEntropyLoss() + net.train() + avg_loss = 0 + # A very standard training loop for image classification + for _ in range(epochs): + for batch in trainloader: + images, labels = batch["image"].to(device), batch["label"].to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + avg_loss += loss.item() / labels.shape[0] + loss.backward() + optimizer.step() + + return avg_loss / len(trainloader) + + +def test(net, testloader, device: str): + """Validate the network on the entire test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + net.eval() + with torch.no_grad(): + for data in testloader: + images, labels = data["image"].to(device), data["label"].to(device) + outputs = net(images) + loss += criterion(outputs, labels).item() + _, predicted = torch.max(outputs.data, 1) + correct += (predicted == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy diff --git a/examples/vit-finetune/pyproject.toml b/examples/vit-finetune/pyproject.toml new file mode 100644 index 000000000000..d014d6b6fb2a --- /dev/null +++ b/examples/vit-finetune/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "vit-finetune" +version = "0.1.0" +description = "FL finetuning of a Vision Transformer with Flower." 
+authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +torch = "2.2.1" +torchvision = "0.17.1" +matplotlib = "3.8.3" diff --git a/examples/vit-finetune/requirements.txt b/examples/vit-finetune/requirements.txt new file mode 100644 index 000000000000..3692be0d6c2c --- /dev/null +++ b/examples/vit-finetune/requirements.txt @@ -0,0 +1,5 @@ +flwr[simulation]>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +matplotlib==3.8.3 +torch==2.2.1 +torchvision==0.17.1 \ No newline at end of file diff --git a/examples/vit-finetune/server.py b/examples/vit-finetune/server.py new file mode 100644 index 000000000000..698bcd45cece --- /dev/null +++ b/examples/vit-finetune/server.py @@ -0,0 +1,61 @@ +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +import flwr as fl + +from dataset import apply_eval_transforms, get_dataset_with_partitions +from model import get_model, set_parameters, test + + +def fit_config(server_round: int): + """Return a configuration with static batch size and (local) epochs.""" + config = { + "lr": 0.01, # Learning rate used by clients + "batch_size": 32, # Batch size to use by clients during fit() + } + return config + + +def get_evaluate_fn( + centralized_testset: Dataset, +): + """Return an evaluation function for centralized evaluation.""" + + def evaluate(server_round, parameters, config): + """Use the entire Oxford Flowers-102 test set for evaluation.""" + + # Determine device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model = get_model() + set_parameters(model, parameters) + model.to(device) + + # Apply transform to dataset + testset = centralized_testset.with_transform(apply_eval_transforms) + + testloader = DataLoader(testset, batch_size=128) + # Run evaluation + loss, accuracy = test(model, testloader, device=device) + + 
return loss, {"accuracy": accuracy} + + return evaluate + + +# Downloads and partition dataset +_, centralized_testset = get_dataset_with_partitions(num_partitions=20) + +# Configure the strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=0.5, # Sample 50% of available clients for training each round + fraction_evaluate=0.0, # No federated evaluation + on_fit_config_fn=fit_config, + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function +) + +# To be used with Flower Next +app = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index e89a09519fed..ddebe51247b2 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -110,7 +110,7 @@ An overview of the FL pipeline built with Flower for this example is illustrated 3. Once on-site training is completed, each client sends back the (now updated) classification head to the Flower server. 4. The Flower server aggregates (via FedAvg) the classification heads in order to obtain a new _global_ classification head. This head will be shared with clients in the next round. -Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. The former, managed by the [`VirtualClientEngine`](https://flower.dev/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. +Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. 
The former, managed by the [`VirtualClientEngine`](https://flower.ai/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. ### Preparing the dataset @@ -147,7 +147,7 @@ INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {' With just 5 FL rounds, the global model should be reaching ~95% validation accuracy. A test accuracy of 97% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~20-30s depending on the amount of data the clients selected in a round have. -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. 
### Federated Finetuning (non-simulated) diff --git a/examples/whisper-federated-finetuning/client.py b/examples/whisper-federated-finetuning/client.py index 2bfeadfbdae6..d3bb217933f8 100644 --- a/examples/whisper-federated-finetuning/client.py +++ b/examples/whisper-federated-finetuning/client.py @@ -146,13 +146,13 @@ def client_fn(cid: str): return WhisperFlowerClient( full_train_dataset, num_classes, disable_tqdm, compile - ) + ).to_client() return client_fn -def run_client(): - """Run clinet.""" +def main(): + """Run client.""" # Parse input arguments args = parser.parse_args() @@ -174,10 +174,11 @@ def run_client(): client_data_path=CLIENT_DATA, ) - fl.client.start_numpy_client( - server_address=f"{args.server_address}:8080", client=client_fn(args.cid) + fl.client.start_client( + server_address=f"{args.server_address}:8080", + client=client_fn(args.cid), ) if __name__ == "__main__": - run_client() + main() diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml index dd5578b8b3d0..27a89578c5a0 100644 --- a/examples/whisper-federated-finetuning/pyproject.toml +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "whisper-flower" version = "0.1.0" description = "On-device Federated Downstreaming for Speech Classification" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -16,4 +16,4 @@ tokenizers = "0.13.3" datasets = "2.14.6" soundfile = "0.12.1" librosa = "0.10.1" -# this example was tested with pytorch 2.1.0 \ No newline at end of file +# this example was tested with pytorch 2.1.0 diff --git a/examples/whisper-federated-finetuning/requirements.txt b/examples/whisper-federated-finetuning/requirements.txt index eb4a5d7eb47b..f16b3d6993ce 100644 --- a/examples/whisper-federated-finetuning/requirements.txt +++ 
b/examples/whisper-federated-finetuning/requirements.txt @@ -3,5 +3,4 @@ tokenizers==0.13.3 datasets==2.14.6 soundfile==0.12.1 librosa==0.10.1 -flwr==1.5.0 -ray==2.6.3 \ No newline at end of file +flwr[simulation]>=1.0, <2.0 \ No newline at end of file diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index 11c4c3f9a08b..dc6d7e3872d6 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -1,7 +1,7 @@ # Flower Example using XGBoost (Comprehensive) This example demonstrates a comprehensive federated learning setup using Flower with XGBoost. -We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. +We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. This examples uses [Flower Datasets](https://flower.ai/docs/datasets/) to retrieve, partition and preprocess the data for each Flower client. It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) example in the following ways: - Arguments parsers of server and clients for hyperparameters selection. @@ -10,6 +10,32 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai - Customised partitioner type (uniform, linear, square, exponential). - Centralised/distributed evaluation. - Bagging/cyclic training methods. +- You can run it with Flower Simulation + +## Training Strategies + +This example provides two training strategies, **bagging aggregation** and **cyclic training**. + +### Bagging Aggregation + +Bagging (bootstrap) aggregation is an ensemble meta-algorithm in machine learning, +used for enhancing the stability and accuracy of machine learning algorithms. +Here, we leverage this algorithm for XGBoost trees. + +Specifically, each client is treated as a bootstrap by random subsampling (data partitioning in FL). 
+At each FL round, all clients boost a number of trees (in this example, 1 tree) based on the local bootstrap samples. +Then, the clients' trees are aggregated on the server and concatenated with the global model from the previous round. +The aggregated tree ensemble is regarded as a new global model. + +This way, let's consider a scenario with M clients. +Given FL round R, the bagging models consist of (M * R) trees. + +### Cyclic Training + +Cyclic XGBoost training performs FL in a client-by-client fashion. +Instead of aggregating multiple clients, +there is only one single client participating in the training per round in the cyclic training scenario. +The trained local XGBoost trees will be passed to the next client as an initialised model for next round's boosting. ## Project Setup @@ -26,7 +52,10 @@ This will create a new directory called `xgboost-comprehensive` containing the f -- server.py <- Defines the server-side logic -- client.py <- Defines the client-side logic -- dataset.py <- Defines the functions of data loading and partitioning --- utils.py <- Defines the arguments parser for clients and server +-- utils.py <- Defines the arguments parser and hyper-parameters +-- client_utils.py <- Defines the client utility functions +-- server_utils.py <- Defines the server utility functions +-- sim.py <- Example of using Flower simulation -- run_bagging.sh <- Commands to run bagging experiments -- run_cyclic.sh <- Commands to run cyclic experiments -- pyproject.toml <- Example dependencies (if you use Poetry) @@ -47,7 +76,7 @@ poetry shell Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +poetry run python -c "import flwr" ``` If you don't see any errors you're good to go!
@@ -62,44 +91,76 @@ pip install -r requirements.txt ## Run Federated Learning with XGBoost and Flower +You can run this example in two ways: either by manually launching the server, and then several clients that connect to it; or by launching a Flower simulation. Both run the same workload, yielding identical results. The former is ideal for deployments on different machines, while the latter makes it easy to simulate large client cohorts in a resource-aware manner. You can read more about how Flower Simulation works in the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html). The commands shown below assume you have activated your environment (if you decide to use Poetry, you can activate it via `poetry shell`). + +### Independent Client/Server Setup + We have two scripts to run bagging and cyclic (client-by-client) experiments. The included `run_bagging.sh` or `run_cyclic.sh` will start the Flower server (using `server.py`), sleep for 15 seconds to ensure that the server is up, and then start 5 Flower clients (using `client.py`) with a small subset of the data from exponential partition distribution. + You can simply start everything in a terminal as follows: ```shell -poetry run ./run_bagging.sh +./run_bagging.sh ``` Or ```shell -poetry run ./run_cyclic.sh +./run_cyclic.sh +``` + +The script starts processes in the background so that you don't have to open six terminal windows. + +You can also run the example without the scripts. 
First, launch the server: + +```bash +python server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N +``` + +Then run at least two clients (each on a new terminal or computer in your network) passing different `PARTITION_ID` and all using the same `N` (denoting the total number of clients or data partitions): + +```bash +python client.py --train-method=bagging/cyclic --partition-id=PARTITION_ID --num-partitions=N +``` + +### Flower Simulation Setup + +We also provide an example code (`sim.py`) to use the simulation capabilities of Flower to simulate federated XGBoost training on either a single machine or a cluster of machines. With default arguments, each client will use 2 CPUs. + +To run bagging aggregation with 5 clients for 30 rounds evaluated on centralised test set: + +```shell +python sim.py --train-method=bagging --pool-size=5 --num-clients-per-round=5 --num-rounds=30 --centralised-eval ``` -The script starts processes in the background so that you don't have to open eleven terminal windows. -If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, -which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. -This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). -If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) -to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). 
+To run cyclic training with 5 clients for 30 rounds evaluated on centralised test set: -You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N` -and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want, -but you have to make sure that each command is run in a different terminal window (or a different computer on the network). +```shell +python sim.py --train-method=cyclic --pool-size=5 --num-rounds=30 --centralised-eval-client +``` In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). -Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) -and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +Check the [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. ### Expected Experimental Results #### Bagging aggregation experiment -![](_static/xgboost_flower_auc.png) +![](_static/xgboost_flower_auc_bagging.png) -The figure above shows the centralised tested AUC performance over FL rounds on 4 experimental settings. +The figure above shows the centralised tested AUC performance over FL rounds with bagging aggregation strategy on 4 experimental settings. One can see that all settings obtain stable performance boost over FL rounds (especially noticeable at the start of training). -As expected, uniform client distribution shows higher AUC values (beyond 83% at the end) than square/exponential setup. -Feel free to explore more interesting experiments by yourself! +As expected, uniform client distribution shows higher AUC values than square/exponential setup. 
+ +#### Cyclic training experiment + +![](_static/xgboost_flower_auc_cyclic.png) + +This figure shows the cyclic training results on centralised test set. +The models with cyclic training require more rounds to converge +because only a single client participates in the training per round. + +Feel free to explore more interesting experiments by yourself! diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png deleted file mode 100644 index e6a4bfb83250..000000000000 Binary files a/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png and /dev/null differ diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png new file mode 100644 index 000000000000..e192df214471 Binary files /dev/null and b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_bagging.png differ diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png new file mode 100644 index 000000000000..731d0fc3fbbc Binary files /dev/null and b/examples/xgboost-comprehensive/_static/xgboost_flower_auc_cyclic.png differ diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index ff7a4adf7977..2d54c3fd63c7 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -1,21 +1,9 @@ import warnings from logging import INFO -import xgboost as xgb import flwr as fl from flwr_datasets import FederatedDataset from flwr.common.logger import log -from flwr.common import ( - Code, - EvaluateIns, - EvaluateRes, - FitIns, - FitRes, - GetParametersIns, - GetParametersRes, - Parameters, - Status, -) from dataset import ( instantiate_partitioner, @@ -23,7 +11,8 @@ transform_dataset_to_dmatrix, resplit, ) -from utils import client_args_parser, BST_PARAMS 
+from utils import client_args_parser, BST_PARAMS, NUM_LOCAL_ROUND +from client_utils import XgbClient warnings.filterwarnings("ignore", category=UserWarning) @@ -32,15 +21,13 @@ # Parse arguments for experimental settings args = client_args_parser() -# Load (HIGGS) dataset and conduct partitioning -num_partitions = args.num_partitions - -# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"] -partitioner_type = args.partitioner_type +# Train method (bagging or cyclic) +train_method = args.train_method -# Instantiate partitioner +# Load (HIGGS) dataset and conduct partitioning +# Instantiate partitioner from ["uniform", "linear", "square", "exponential"] partitioner = instantiate_partitioner( - partitioner_type=partitioner_type, num_partitions=num_partitions + partitioner_type=args.partitioner_type, num_partitions=args.num_partitions ) fds = FederatedDataset( dataset="jxie/higgs", @@ -48,25 +35,22 @@ resplitter=resplit, ) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -node_id = args.node_id -partition = fds.load_partition(node_id=node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") if args.centralised_eval: # Use centralised test set for evaluation train_data = partition - valid_data = fds.load_full("test") + valid_data = fds.load_split("test") valid_data.set_format("numpy") num_train = train_data.shape[0] num_val = valid_data.shape[0] else: # Train/test splitting - SEED = args.seed - test_fraction = args.test_fraction train_data, valid_data, num_train, num_val = train_test_split( - partition, test_fraction=test_fraction, seed=SEED + partition, test_fraction=args.test_fraction, seed=args.seed ) # Reformat data to DMatrix for xgboost @@ -74,101 +58,25 @@ train_dmatrix = transform_dataset_to_dmatrix(train_data) valid_dmatrix = transform_dataset_to_dmatrix(valid_data) - # Hyper-parameters 
for xgboost training -num_local_round = 1 +num_local_round = NUM_LOCAL_ROUND params = BST_PARAMS - -# Define Flower client -class XgbClient(fl.client.Client): - def __init__(self): - self.bst = None - self.config = None - - def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: - _ = (self, ins) - return GetParametersRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[]), - ) - - def _local_boost(self): - # Update trees based on local training data. - for i in range(num_local_round): - self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) - - # Bagging: extract the last N=num_local_round trees for sever aggregation - # Cyclic: return the entire model - bst = ( - self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() - ] - if args.train_method == "bagging" - else self.bst - ) - - return bst - - def fit(self, ins: FitIns) -> FitRes: - if not self.bst: - # First round local training - log(INFO, "Start training at round 1") - bst = xgb.train( - params, - train_dmatrix, - num_boost_round=num_local_round, - evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], - ) - self.config = bst.save_config() - self.bst = bst - else: - for item in ins.parameters.tensors: - global_model = bytearray(item) - - # Load global model into booster - self.bst.load_model(global_model) - self.bst.load_config(self.config) - - bst = self._local_boost() - - local_model = bst.save_raw("json") - local_model_bytes = bytes(local_model) - - return FitRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=num_train, - metrics={}, - ) - - def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - eval_results = self.bst.eval_set( - evals=[(valid_dmatrix, "valid")], - iteration=self.bst.num_boosted_rounds() - 1, - ) - auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) - - 
global_round = ins.config["global_round"] - log(INFO, f"AUC = {auc} at round {global_round}") - - return EvaluateRes( - status=Status( - code=Code.OK, - message="OK", - ), - loss=0.0, - num_examples=num_val, - metrics={"AUC": auc}, - ) - +# Setup learning rate +if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.num_partitions + params.update({"eta": new_lr}) # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) +fl.client.start_client( + server_address="127.0.0.1:8080", + client=XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ), +) diff --git a/examples/xgboost-comprehensive/client_utils.py b/examples/xgboost-comprehensive/client_utils.py new file mode 100644 index 000000000000..d2e07677ef97 --- /dev/null +++ b/examples/xgboost-comprehensive/client_utils.py @@ -0,0 +1,126 @@ +from logging import INFO +import xgboost as xgb + +import flwr as fl +from flwr.common.logger import log +from flwr.common import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + Parameters, + Status, +) + + +class XgbClient(fl.client.Client): + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params + self.train_method = train_method + + def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: + _ = (self, ins) + return GetParametersRes( + status=Status( + code=Code.OK, + message="OK", + ), + parameters=Parameters(tensor_type="", tensors=[]), + ) + + def _local_boost(self, bst_input): + # Update trees based on local training data. 
+ for i in range(self.num_local_round): + bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) + + # Bagging: extract the last N=num_local_round trees for sever aggregation + # Cyclic: return the entire model + bst = ( + bst_input[ + bst_input.num_boosted_rounds() + - self.num_local_round : bst_input.num_boosted_rounds() + ] + if self.train_method == "bagging" + else bst_input + ) + + return bst + + def fit(self, ins: FitIns) -> FitRes: + global_round = int(ins.config["global_round"]) + if global_round == 1: + # First round local training + bst = xgb.train( + self.params, + self.train_dmatrix, + num_boost_round=self.num_local_round, + evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], + ) + else: + bst = xgb.Booster(params=self.params) + for item in ins.parameters.tensors: + global_model = bytearray(item) + + # Load global model into booster + bst.load_model(global_model) + + # Local training + bst = self._local_boost(bst) + + # Save model + local_model = bst.save_raw("json") + local_model_bytes = bytes(local_model) + + return FitRes( + status=Status( + code=Code.OK, + message="OK", + ), + parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), + num_examples=self.num_train, + metrics={}, + ) + + def evaluate(self, ins: EvaluateIns) -> EvaluateRes: + # Load global model + bst = xgb.Booster(params=self.params) + for para in ins.parameters.tensors: + para_b = bytearray(para) + bst.load_model(para_b) + + # Run evaluation + eval_results = bst.eval_set( + evals=[(self.valid_dmatrix, "valid")], + iteration=bst.num_boosted_rounds() - 1, + ) + auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + + global_round = ins.config["global_round"] + log(INFO, f"AUC = {auc} at round {global_round}") + + return EvaluateRes( + status=Status( + code=Code.OK, + message="OK", + ), + loss=0.0, + num_examples=self.num_val, + metrics={"AUC": auc}, + ) diff --git a/examples/xgboost-comprehensive/dataset.py 
b/examples/xgboost-comprehensive/dataset.py index bcf2e00b30af..94959925f833 100644 --- a/examples/xgboost-comprehensive/dataset.py +++ b/examples/xgboost-comprehensive/dataset.py @@ -39,12 +39,18 @@ def train_test_split(partition: Dataset, test_fraction: float, seed: int): def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix: """Transform dataset to DMatrix format for xgboost.""" - x = data["inputs"] - y = data["label"] + x, y = separate_xy(data) new_data = xgb.DMatrix(x, label=y) return new_data +def separate_xy(data: Union[Dataset, DatasetDict]): + """Return outputs of x (data) and y (labels) .""" + x = data["inputs"] + y = data["label"] + return x, y + + def resplit(dataset: DatasetDict) -> DatasetDict: """Increase the quantity of centralised test samples from 500K to 1M.""" return DatasetDict( diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml index bbfbb4134b8d..b257801cb420 100644 --- a/examples/xgboost-comprehensive/pyproject.toml +++ b/examples/xgboost-comprehensive/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "xgboost-comprehensive" version = "0.1.0" description = "Federated XGBoost with Flower (comprehensive)" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = ">=1.0,<2.0" +flwr-nightly = { extras = ["simulation"], version = ">=1.7.0,<2.0" } flwr-datasets = ">=0.0.2,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-comprehensive/requirements.txt b/examples/xgboost-comprehensive/requirements.txt index c37ac2b6ad6d..b5b1d83bcdd1 100644 --- a/examples/xgboost-comprehensive/requirements.txt +++ b/examples/xgboost-comprehensive/requirements.txt @@ -1,3 +1,3 @@ -flwr-nightly>=1.0, <2.0 +flwr[simulation]>=1.7.0, <2.0 flwr-datasets>=0.0.2, <1.0.0 xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-comprehensive/run_bagging.sh 
b/examples/xgboost-comprehensive/run_bagging.sh index 7920f6bf5e55..a6300b781a06 100755 --- a/examples/xgboost-comprehensive/run_bagging.sh +++ b/examples/xgboost-comprehensive/run_bagging.sh @@ -3,12 +3,12 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ echo "Starting server" -python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval & -sleep 15 # Sleep for 15s to give the server enough time to start +python3 server.py --pool-size=5 --num-rounds=30 --num-clients-per-round=5 --centralised-eval & +sleep 30 # Sleep for 30s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --num-partitions=5 --partitioner-type=exponential & + python3 client.py --partition-id=$i --num-partitions=5 --partitioner-type=exponential & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh index 47e09fd8faef..258bdf2fe0d8 100755 --- a/examples/xgboost-comprehensive/run_cyclic.sh +++ b/examples/xgboost-comprehensive/run_cyclic.sh @@ -8,7 +8,7 @@ sleep 15 # Sleep for 15s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & + python3 client.py --partition-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 1cf4ba79fa50..939819641438 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -1,19 +1,19 @@ import warnings -from typing import Dict, List, Optional from logging import INFO -import xgboost as xgb import flwr as fl from flwr.common.logger 
import log -from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset -from flwr.server.strategy import FedXgbBagging -from flwr.server.strategy import FedXgbCyclic -from flwr.server.client_proxy import ClientProxy -from flwr.server.criterion import Criterion -from flwr.server.client_manager import SimpleClientManager - -from utils import server_args_parser, BST_PARAMS +from flwr.server.strategy import FedXgbBagging, FedXgbCyclic + +from utils import server_args_parser +from server_utils import ( + eval_config, + fit_config, + evaluate_metrics_aggregation, + get_evaluate_fn, + CyclicClientManager, +) from dataset import resplit, transform_dataset_to_dmatrix @@ -34,97 +34,11 @@ fds = FederatedDataset( dataset="jxie/higgs", partitioners={"train": 20}, resplitter=resplit ) - test_set = fds.load_full("test") + log(INFO, "Loading centralised test set...") + test_set = fds.load_split("test") test_set.set_format("numpy") test_dmatrix = transform_dataset_to_dmatrix(test_set) -# Hyper-parameters used for initialisation -params = BST_PARAMS - - -def eval_config(rnd: int) -> Dict[str, str]: - """Return a configuration with global epochs.""" - config = { - "global_round": str(rnd), - } - return config - - -def evaluate_metrics_aggregation(eval_metrics): - """Return an aggregated metric (AUC) for evaluation.""" - total_num = sum([num for num, _ in eval_metrics]) - auc_aggregated = ( - sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num - ) - metrics_aggregated = {"AUC": auc_aggregated} - return metrics_aggregated - - -def get_evaluate_fn(test_data): - """Return a function for centralised evaluation.""" - - def evaluate_fn( - server_round: int, parameters: Parameters, config: Dict[str, Scalar] - ): - # If at the first round, skip the evaluation - if server_round == 0: - return 0, {} - else: - bst = xgb.Booster(params=params) - for para in parameters.tensors: - para_b = bytearray(para) - - # Load global model - 
bst.load_model(para_b) - # Run evaluation - eval_results = bst.eval_set( - evals=[(test_data, "valid")], - iteration=bst.num_boosted_rounds() - 1, - ) - auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) - log(INFO, f"AUC = {auc} at round {server_round}") - - return 0, {"AUC": auc} - - return evaluate_fn - - -class CyclicClientManager(SimpleClientManager): - """Provides a cyclic client selection rule.""" - - def sample( - self, - num_clients: int, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: - """Sample a number of Flower ClientProxy instances.""" - - # Block until at least num_clients are connected. - if min_num_clients is None: - min_num_clients = num_clients - self.wait_for(min_num_clients) - - # Sample clients which meet the criterion - available_cids = list(self.clients) - if criterion is not None: - available_cids = [ - cid for cid in available_cids if criterion.select(self.clients[cid]) - ] - - if num_clients > len(available_cids): - log( - INFO, - "Sampling failed: number of available clients" - " (%s) is less than number of requested clients (%s).", - len(available_cids), - num_clients, - ) - return [] - - # Return all available clients - return [self.clients[cid] for cid in available_cids] - # Define strategy if train_method == "bagging": @@ -137,9 +51,10 @@ def sample( min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0, fraction_evaluate=1.0 if not centralised_eval else 0.0, on_evaluate_config_fn=eval_config, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation - if not centralised_eval - else None, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not centralised_eval else None + ), ) else: # Cyclic training @@ -149,6 +64,7 @@ def sample( fraction_evaluate=1.0, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, ) # Start 
Flower server diff --git a/examples/xgboost-comprehensive/server_utils.py b/examples/xgboost-comprehensive/server_utils.py new file mode 100644 index 000000000000..35a31bd9adac --- /dev/null +++ b/examples/xgboost-comprehensive/server_utils.py @@ -0,0 +1,101 @@ +from typing import Dict, List, Optional +from logging import INFO +import xgboost as xgb +from flwr.common.logger import log +from flwr.common import Parameters, Scalar +from flwr.server.client_manager import SimpleClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.criterion import Criterion +from utils import BST_PARAMS + + +def eval_config(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + +def fit_config(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + +def evaluate_metrics_aggregation(eval_metrics): + """Return an aggregated metric (AUC) for evaluation.""" + total_num = sum([num for num, _ in eval_metrics]) + auc_aggregated = ( + sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num + ) + metrics_aggregated = {"AUC": auc_aggregated} + return metrics_aggregated + + +def get_evaluate_fn(test_data): + """Return a function for centralised evaluation.""" + + def evaluate_fn( + server_round: int, parameters: Parameters, config: Dict[str, Scalar] + ): + # If at the first round, skip the evaluation + if server_round == 0: + return 0, {} + else: + bst = xgb.Booster(params=BST_PARAMS) + for para in parameters.tensors: + para_b = bytearray(para) + + # Load global model + bst.load_model(para_b) + # Run evaluation + eval_results = bst.eval_set( + evals=[(test_data, "valid")], + iteration=bst.num_boosted_rounds() - 1, + ) + auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + log(INFO, f"AUC = {auc} at round {server_round}") + + return 0, {"AUC": auc} + + return 
evaluate_fn + + +class CyclicClientManager(SimpleClientManager): + """Provides a cyclic client selection rule.""" + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + + # Block until at least num_clients are connected. + if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + # Return all available clients + return [self.clients[cid] for cid in available_cids] diff --git a/examples/xgboost-comprehensive/sim.py b/examples/xgboost-comprehensive/sim.py new file mode 100644 index 000000000000..c9481f1cdd5d --- /dev/null +++ b/examples/xgboost-comprehensive/sim.py @@ -0,0 +1,188 @@ +import warnings +from logging import INFO +import xgboost as xgb +from tqdm import tqdm + +import flwr as fl +from flwr_datasets import FederatedDataset +from flwr.common.logger import log +from flwr.server.strategy import FedXgbBagging, FedXgbCyclic + +from dataset import ( + instantiate_partitioner, + train_test_split, + transform_dataset_to_dmatrix, + separate_xy, + resplit, +) +from utils import ( + sim_args_parser, + NUM_LOCAL_ROUND, + BST_PARAMS, +) +from server_utils import ( + eval_config, + fit_config, + evaluate_metrics_aggregation, + get_evaluate_fn, + CyclicClientManager, +) +from client_utils import XgbClient + + +warnings.filterwarnings("ignore", category=UserWarning) + + +def get_client_fn( + train_data_list, valid_data_list, train_method, params, num_local_round 
+): + """Return a function to construct a client. + + The VirtualClientEngine will execute this function whenever a client is sampled by + the strategy to participate. + """ + + def client_fn(cid: str) -> fl.client.Client: + """Construct a FlowerClient with its own dataset partition.""" + x_train, y_train = train_data_list[int(cid)][0] + x_valid, y_valid = valid_data_list[int(cid)][0] + + # Reformat data to DMatrix + train_dmatrix = xgb.DMatrix(x_train, label=y_train) + valid_dmatrix = xgb.DMatrix(x_valid, label=y_valid) + + # Fetch the number of examples + num_train = train_data_list[int(cid)][1] + num_val = valid_data_list[int(cid)][1] + + # Create and return client + return XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ) + + return client_fn + + +def main(): + # Parse arguments for experimental settings + args = sim_args_parser() + + # Load (HIGGS) dataset and conduct partitioning + partitioner = instantiate_partitioner( + partitioner_type=args.partitioner_type, num_partitions=args.pool_size + ) + fds = FederatedDataset( + dataset="jxie/higgs", + partitioners={"train": partitioner}, + resplitter=resplit, + ) + + # Load centralised test set + if args.centralised_eval or args.centralised_eval_client: + log(INFO, "Loading centralised test set...") + test_data = fds.load_split("test") + test_data.set_format("numpy") + num_test = test_data.shape[0] + test_dmatrix = transform_dataset_to_dmatrix(test_data) + + # Load partitions and reformat data to DMatrix for xgboost + log(INFO, "Loading client local partitions...") + train_data_list = [] + valid_data_list = [] + + # Load and process all client partitions. This upfront cost is amortized soon + # after the simulation begins since clients wont need to preprocess their partition. 
+ for partition_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with partition_id + partition = fds.load_partition(partition_id=partition_id, split="train") + partition.set_format("numpy") + + if args.centralised_eval_client: + # Use centralised test set for evaluation + train_data = partition + num_train = train_data.shape[0] + x_test, y_test = separate_xy(test_data) + valid_data_list.append(((x_test, y_test), num_test)) + else: + # Train/test splitting + train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=args.test_fraction, seed=args.seed + ) + x_valid, y_valid = separate_xy(valid_data) + valid_data_list.append(((x_valid, y_valid), num_val)) + + x_train, y_train = separate_xy(train_data) + train_data_list.append(((x_train, y_train), num_train)) + + # Define strategy + if args.train_method == "bagging": + # Bagging training + strategy = FedXgbBagging( + evaluate_function=( + get_evaluate_fn(test_dmatrix) if args.centralised_eval else None + ), + fraction_fit=(float(args.num_clients_per_round) / args.pool_size), + min_fit_clients=args.num_clients_per_round, + min_available_clients=args.pool_size, + min_evaluate_clients=( + args.num_evaluate_clients if not args.centralised_eval else 0 + ), + fraction_evaluate=1.0 if not args.centralised_eval else 0.0, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not args.centralised_eval else None + ), + ) + else: + # Cyclic training + strategy = FedXgbCyclic( + fraction_fit=1.0, + min_available_clients=args.pool_size, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + ) + + # Resources to be assigned to each virtual client + # In this example we use CPU by default + client_resources = { + "num_cpus": args.num_cpus_per_client, + "num_gpus": 
0.0, + } + + # Hyper-parameters for xgboost training + num_local_round = NUM_LOCAL_ROUND + params = BST_PARAMS + + # Setup learning rate + if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.pool_size + params.update({"eta": new_lr}) + + # Start simulation + fl.simulation.start_simulation( + client_fn=get_client_fn( + train_data_list, + valid_data_list, + args.train_method, + params, + num_local_round, + ), + num_clients=args.pool_size, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + client_manager=CyclicClientManager() if args.train_method == "cyclic" else None, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 8acdbbb88a7e..abc100da1ade 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -1,6 +1,8 @@ import argparse +# Hyper-parameters for xgboost training +NUM_LOCAL_ROUND = 1 BST_PARAMS = { "objective": "binary:logistic", "eta": 0.1, # Learning rate @@ -35,10 +37,10 @@ def client_args_parser(): help="Partitioner types.", ) parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) parser.add_argument( "--seed", default=42, type=int, help="Seed used for train/test splitting." 
@@ -52,7 +54,12 @@ def client_args_parser():
     parser.add_argument(
         "--centralised-eval",
         action="store_true",
-        help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
+        help="Conduct evaluation on centralised test set (True), or on hold-out data (False).",
+    )
+    parser.add_argument(
+        "--scaled-lr",
+        action="store_true",
+        help="Perform scaled learning rate based on the number of clients (True).",
     )
 
     args = parser.parse_args()
@@ -96,3 +103,78 @@ def server_args_parser():
     args = parser.parse_args()
 
     return args
+
+
+def sim_args_parser():
+    """Parse arguments to define experimental settings for the simulation."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--train-method",
+        default="bagging",
+        type=str,
+        choices=["bagging", "cyclic"],
+        help="Training methods selected from bagging aggregation or cyclic training.",
+    )
+
+    # Server side
+    parser.add_argument(
+        "--pool-size", default=5, type=int, help="Number of total clients."
+    )
+    parser.add_argument(
+        "--num-rounds", default=30, type=int, help="Number of FL rounds."
+    )
+    parser.add_argument(
+        "--num-clients-per-round",
+        default=5,
+        type=int,
+        help="Number of clients participating in training each round.",
+    )
+    parser.add_argument(
+        "--num-evaluate-clients",
+        default=5,
+        type=int,
+        help="Number of clients selected for evaluation.",
+    )
+    parser.add_argument(
+        "--centralised-eval",
+        action="store_true",
+        help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
+    )
+    parser.add_argument(
+        "--num-cpus-per-client",
+        default=2,
+        type=int,
+        help="Number of CPUs used per client.",
+    )
+
+    # Client side
+    parser.add_argument(
+        "--partitioner-type",
+        default="uniform",
+        type=str,
+        choices=["uniform", "linear", "square", "exponential"],
+        help="Partitioner types.",
+    )
+    parser.add_argument(
+        "--seed", default=42, type=int, help="Seed used for train/test splitting."
+ ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval-client", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Perform scaled learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index 5174c236c668..72dde5706e8d 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -67,13 +67,13 @@ To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id=0 +python3 client.py --partition-id=0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id=1 +python3 client.py --partition-id=1 ``` You will see that XGBoost is starting a federated training. @@ -85,4 +85,4 @@ poetry run ./run.sh ``` Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) -and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +and [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index b5eab59ba14d..6ac23ae15148 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -24,13 +24,13 @@ warnings.filterwarnings("ignore", category=UserWarning) -# Define arguments parser for the client/node ID. +# Define arguments parser for the client/partition ID. 
parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() @@ -61,9 +61,9 @@ def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core. partitioner = IidPartitioner(num_partitions=30) fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") # Train/test splitting @@ -173,4 +173,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client()) diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml index 7b3cbd9659a2..c16542ea7ffe 100644 --- a/examples/xgboost-quickstart/pyproject.toml +++ b/examples/xgboost-quickstart/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.core.masonry.api" name = "xgboost-quickstart" version = "0.1.0" description = "Federated XGBoost with Flower (quickstart)" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.6.0,<2.0" +flwr = ">=1.7.0,<2.0" flwr-datasets = ">=0.0.1,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt index 4ccd5587bfc3..c6949e0651c5 100644 --- a/examples/xgboost-quickstart/requirements.txt +++ b/examples/xgboost-quickstart/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.6.0, <2.0 +flwr>=1.7.0, <2.0 
flwr-datasets>=0.0.1, <1.0.0 xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh index 6287145bfb5f..b35af58222ab 100755 --- a/examples/xgboost-quickstart/run.sh +++ b/examples/xgboost-quickstart/run.sh @@ -8,7 +8,7 @@ sleep 5 # Sleep for 5s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python3 client.py --node-id=$i & + python3 client.py --partition-id=$i & done # Enable CTRL+C to stop all background processes diff --git a/pyproject.toml b/pyproject.toml index 2349d554a409..e0514254ecac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,14 +4,14 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "flwr" -version = "1.7.0" +version = "1.8.0" description = "Flower: A Friendly Federated Learning Framework" license = "Apache-2.0" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" keywords = [ "flower", "fl", @@ -52,21 +52,26 @@ exclude = [ ] [tool.poetry.scripts] +flwr = "flwr.cli.app:app" flower-driver-api = "flwr.server:run_driver_api" flower-fleet-api = "flwr.server:run_fleet_api" -flower-server = "flwr.server:run_server" -flower-client = "flwr.client:run_client" +flower-superlink = "flwr.server:run_superlink" +flower-client-app = "flwr.client:run_client_app" +flower-server-app = "flwr.server:run_server_app" +flower-simulation = "flwr.simulation:run_simulation_from_cli" [tool.poetry.dependencies] python = "^3.8" # Mandatory dependencies numpy = "^1.21.0" -grpcio = "^1.48.2,!=1.52.0" -protobuf = "^3.19.0" -cryptography = "^41.0.2" +grpcio = "^1.60.0" +protobuf = "^4.25.2" +cryptography = "^42.0.4" pycryptodome = "^3.18.0" iterators = "^0.0.2" -# Optional dependencies (VCE) +typer = { version = "^0.9.0", extras=["all"] } +tomli 
= "^2.0.1" +# Optional dependencies (Simulation Engine) ray = { version = "==2.6.3", optional = true } pydantic = { version = "<2.0.0", optional = true } # Optional dependencies (REST transport layer) @@ -81,21 +86,21 @@ rest = ["requests", "starlette", "uvicorn"] [tool.poetry.group.dev.dependencies] types-dataclasses = "==0.6.6" types-protobuf = "==3.19.18" -types-requests = "==2.31.0.2" -types-setuptools = "==68.2.0.0" -clang-format = "==17.0.4" -isort = "==5.12.0" -black = { version = "==23.10.1", extras = ["jupyter"] } +types-requests = "==2.31.0.20240125" +types-setuptools = "==69.0.0.20240125" +clang-format = "==17.0.6" +isort = "==5.13.2" +black = { version = "==24.2.0", extras = ["jupyter"] } docformatter = "==1.7.5" -mypy = "==1.6.1" -pylint = "==2.13.9" +mypy = "==1.8.0" +pylint = "==3.0.3" flake8 = "==5.0.4" -pytest = "==7.4.3" +pytest = "==7.4.4" pytest-cov = "==4.1.0" -pytest-watch = "==4.2.0" -grpcio-tools = "==1.48.2" +pytest-watcher = "==0.4.1" +grpcio-tools = "==1.60.0" mypy-protobuf = "==3.2.0" -jupyterlab = "==4.0.8" +jupyterlab = "==4.0.12" rope = "==1.11.0" semver = "==3.0.2" sphinx = "==6.2.1" @@ -109,7 +114,7 @@ furo = "==2023.9.10" sphinx-reredirects = "==0.1.3" nbsphinx = "==0.9.3" nbstripout = "==0.6.1" -ruff = "==0.1.4" +ruff = "==0.1.9" sphinx-argparse = "==0.4.0" pipreqs = "==0.4.13" mdformat-gfm = "==0.3.5" @@ -120,7 +125,8 @@ twine = "==4.0.2" pyroma = "==4.2" check-wheel-contents = "==0.4.0" GitPython = "==3.1.32" -licensecheck = "==2023.5.1" +PyGithub = "==2.1.1" +licensecheck = "==2024" [tool.isort] line_length = 88 @@ -136,7 +142,7 @@ line-length = 88 target-version = ["py38", "py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" [tool.pytest.ini_options] minversion = "6.2" @@ -145,6 +151,16 @@ testpaths = [ "src/py/flwr", "src/py/flwr_tool", ] +filterwarnings = 
"ignore::DeprecationWarning" + +[tool.pytest-watcher] +now = false +clear = true +delay = 0.2 +runner = "pytest" +runner_args = ["-s", "-vvvvv"] +patterns = ["*.py"] +ignore_patterns = [] [tool.mypy] plugins = [ @@ -183,7 +199,7 @@ target-version = "py38" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] -ignore = ["B024", "B027"] +ignore = ["B024", "B027", "D205", "D209"] exclude = [ ".bzr", ".direnv", diff --git a/src/docker/client/Dockerfile b/src/docker/client/Dockerfile new file mode 100644 index 000000000000..0755a7989281 --- /dev/null +++ b/src/docker/client/Dockerfile @@ -0,0 +1,8 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. + +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE_TAG +FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG + +ARG FLWR_VERSION +RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION} diff --git a/src/docker/server/Dockerfile b/src/docker/server/Dockerfile index 9bf3214bb42c..faa9cf2e56fe 100644 --- a/src/docker/server/Dockerfile +++ b/src/docker/server/Dockerfile @@ -1,13 +1,14 @@ # Copyright 2023 Flower Labs GmbH. All Rights Reserved. 
-ARG BASE_IMAGE_VERSION=py3.11-ubuntu22.04 -FROM flwr/base:$BASE_IMAGE_VERSION as server +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE_TAG=py3.11-ubuntu22.04 +FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG as server WORKDIR /app ARG FLWR_VERSION RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION} -ENTRYPOINT ["python", "-c", "from flwr.server import run_server; run_server()"] +ENTRYPOINT ["python", "-c", "from flwr.server import run_superlink; run_superlink()"] # Test if Flower can be successfully installed and imported FROM server as test -RUN python -c "from flwr.server import run_server" +RUN python -c "from flwr.server import run_superlink" diff --git a/src/kotlin/flwr/src/main/AndroidManifest.xml b/src/kotlin/flwr/src/main/AndroidManifest.xml index 8bdb7e14b389..3cb3262db448 100644 --- a/src/kotlin/flwr/src/main/AndroidManifest.xml +++ b/src/kotlin/flwr/src/main/AndroidManifest.xml @@ -1,4 +1,5 @@ - + + diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index eb948217a4de..bc0062c4a51f 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -21,8 +21,8 @@ import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; service Driver { - // Request workload_id - rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {} + // Request run_id + rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} // Return a set of nodes rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {} @@ -34,12 +34,12 @@ service Driver { rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {} } -// CreateWorkload -message CreateWorkloadRequest {} -message CreateWorkloadResponse { sint64 workload_id = 1; } +// CreateRun +message CreateRunRequest {} +message CreateRunResponse { sint64 run_id = 1; } // GetNodes messages -message GetNodesRequest { sint64 workload_id = 1; } +message GetNodesRequest { sint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; 
} // PushTaskIns messages diff --git a/src/proto/flwr/proto/error.proto b/src/proto/flwr/proto/error.proto new file mode 100644 index 000000000000..a35af7f8af67 --- /dev/null +++ b/src/proto/flwr/proto/error.proto @@ -0,0 +1,23 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message Error { + sint64 code = 1; + string reason = 2; +} diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto new file mode 100644 index 000000000000..d51d0f9ce416 --- /dev/null +++ b/src/proto/flwr/proto/recordset.proto @@ -0,0 +1,76 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message DoubleList { repeated double vals = 1; } +message Sint64List { repeated sint64 vals = 1; } +message BoolList { repeated bool vals = 1; } +message StringList { repeated string vals = 1; } +message BytesList { repeated bytes vals = 1; } + +message Array { + string dtype = 1; + repeated int32 shape = 2; + string stype = 3; + bytes data = 4; +} + +message MetricsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + } +} + +message ConfigsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + bool bool = 3; + string string = 4; + bytes bytes = 5; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + BoolList bool_list = 23; + StringList string_list = 24; + BytesList bytes_list = 25; + } +} + +message ParametersRecord { + repeated string data_keys = 1; + repeated Array data_values = 2; +} + +message MetricsRecord { map data = 1; } + +message ConfigsRecord { map data = 1; } + +message RecordSet { + map parameters = 1; + map metrics = 2; + map configs = 3; +} diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 2205ef2815c8..423df76f1335 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -18,7 +18,9 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/node.proto"; +import "flwr/proto/recordset.proto"; import "flwr/proto/transport.proto"; +import "flwr/proto/error.proto"; message Task { Node producer = 1; @@ -27,48 +29,21 @@ message Task { string delivered_at = 4; string ttl = 5; repeated string ancestry = 6; - SecureAggregation sa = 7; - - ServerMessage legacy_server_message = 101 [ deprecated = true ]; - ClientMessage legacy_client_message = 102 [ deprecated = true ]; + string 
task_type = 7; + RecordSet recordset = 8; + Error error = 9; } message TaskIns { string task_id = 1; string group_id = 2; - sint64 workload_id = 3; + sint64 run_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - sint64 workload_id = 3; + sint64 run_id = 3; Task task = 4; } - -message Value { - message DoubleList { repeated double vals = 1; } - message Sint64List { repeated sint64 vals = 1; } - message BoolList { repeated bool vals = 1; } - message StringList { repeated string vals = 1; } - message BytesList { repeated bytes vals = 1; } - - oneof value { - // Single element - double double = 1; - sint64 sint64 = 2; - bool bool = 3; - string string = 4; - bytes bytes = 5; - - // List types - DoubleList double_list = 21; - Sint64List sint64_list = 22; - BoolList bool_list = 23; - StringList string_list = 24; - BytesList bytes_list = 25; - } -} - -message SecureAggregation { map named_values = 1; } diff --git a/src/py/flwr/__init__.py b/src/py/flwr/__init__.py index e05799280339..ccaf07c6012f 100644 --- a/src/py/flwr/__init__.py +++ b/src/py/flwr/__init__.py @@ -17,13 +17,11 @@ from flwr.common.version import package_version as _package_version -from . import client, common, driver, flower, server, simulation +from . import client, common, server, simulation __all__ = [ "client", "common", - "driver", - "flower", "server", "simulation", ] diff --git a/src/py/flwr/cli/__init__.py b/src/py/flwr/cli/__init__.py new file mode 100644 index 000000000000..d4d3b8ac4d48 --- /dev/null +++ b/src/py/flwr/cli/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface.""" diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py new file mode 100644 index 000000000000..1c6ee6c97841 --- /dev/null +++ b/src/py/flwr/cli/app.py @@ -0,0 +1,37 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower command line interface.""" + +import typer + +from .example import example +from .new import new +from .run import run + +app = typer.Typer( + help=typer.style( + "flwr is the Flower command line interface.", + fg=typer.colors.BRIGHT_YELLOW, + bold=True, + ), + no_args_is_help=True, +) + +app.command()(new) +app.command()(example) +app.command()(run) + +if __name__ == "__main__": + app() diff --git a/src/py/flwr/cli/example.py b/src/py/flwr/cli/example.py new file mode 100644 index 000000000000..625ca8729640 --- /dev/null +++ b/src/py/flwr/cli/example.py @@ -0,0 +1,64 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `example` command.""" + +import json +import os +import subprocess +import tempfile +import urllib.request + +from .utils import prompt_options + + +def example() -> None: + """Clone a Flower example. + + All examples available in the Flower repository are available through this command. 
+ """ + # Load list of examples directly from GitHub + url = "https://api.github.com/repos/adap/flower/git/trees/main" + with urllib.request.urlopen(url) as res: + data = json.load(res) + examples_directory_url = [ + item["url"] for item in data["tree"] if item["path"] == "examples" + ][0] + + with urllib.request.urlopen(examples_directory_url) as res: + data = json.load(res) + example_names = [ + item["path"] for item in data["tree"] if item["path"] not in [".gitignore"] + ] + + example_name = prompt_options( + "Please select example by typing in the number", + example_names, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + subprocess.check_output( + [ + "git", + "clone", + "--depth=1", + "https://github.com/adap/flower.git", + tmpdirname, + ] + ) + examples_dir = os.path.join(tmpdirname, "examples", example_name) + subprocess.check_output(["mv", examples_dir, "."]) + + print() + print(f"Example ready to use in {os.path.join(os.getcwd(), example_name)}") diff --git a/src/py/flwr/cli/flower_toml.py b/src/py/flwr/cli/flower_toml.py new file mode 100644 index 000000000000..103f83532054 --- /dev/null +++ b/src/py/flwr/cli/flower_toml.py @@ -0,0 +1,140 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utility to validate the `flower.toml` file.""" + +import os +from typing import Any, Dict, List, Optional, Tuple + +import tomli + +from flwr.common import object_ref + + +def load_and_validate_with_defaults( + path: Optional[str] = None, +) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]: + """Load and validate flower.toml as dict. + + Returns + ------- + Tuple[Optional[config], List[str], List[str]] + A tuple with the optional config in case it exists and is valid + and associated errors and warnings. + """ + config = load(path) + + if config is None: + errors = [ + "Project configuration could not be loaded. flower.toml does not exist." + ] + return (None, errors, []) + + is_valid, errors, warnings = validate(config) + + if not is_valid: + return (None, errors, warnings) + + # Apply defaults + defaults = { + "flower": { + "engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}} + } + } + config = apply_defaults(config, defaults) + + return (config, errors, warnings) + + +def load(path: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Load flower.toml and return as dict.""" + if path is None: + cur_dir = os.getcwd() + toml_path = os.path.join(cur_dir, "flower.toml") + else: + toml_path = path + + if not os.path.isfile(toml_path): + return None + + with open(toml_path, encoding="utf-8") as toml_file: + data = tomli.loads(toml_file.read()) + return data + + +def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: + """Validate flower.toml fields.""" + errors = [] + warnings = [] + + if "project" not in config: + errors.append("Missing [project] section") + else: + if "name" not in config["project"]: + errors.append('Property "name" missing in [project]') + if "version" not in config["project"]: + errors.append('Property "version" missing in [project]') + if "description" not in config["project"]: + 
warnings.append('Recommended property "description" missing in [project]') + if "license" not in config["project"]: + warnings.append('Recommended property "license" missing in [project]') + if "authors" not in config["project"]: + warnings.append('Recommended property "authors" missing in [project]') + + if "flower" not in config: + errors.append("Missing [flower] section") + elif "components" not in config["flower"]: + errors.append("Missing [flower.components] section") + else: + if "serverapp" not in config["flower"]["components"]: + errors.append('Property "serverapp" missing in [flower.components]') + if "clientapp" not in config["flower"]["components"]: + errors.append('Property "clientapp" missing in [flower.components]') + + return len(errors) == 0, errors, warnings + + +def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: + """Validate flower.toml.""" + is_valid, errors, warnings = validate_fields(config) + + if not is_valid: + return False, errors, warnings + + # Validate serverapp + is_valid, reason = object_ref.validate(config["flower"]["components"]["serverapp"]) + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + # Validate clientapp + is_valid, reason = object_ref.validate(config["flower"]["components"]["clientapp"]) + + if not is_valid and isinstance(reason, str): + return False, [reason], [] + + return True, [], [] + + +def apply_defaults( + config: Dict[str, Any], + defaults: Dict[str, Any], +) -> Dict[str, Any]: + """Apply defaults to config.""" + for key in defaults: + if key in config: + if isinstance(config[key], dict) and isinstance(defaults[key], dict): + apply_defaults(config[key], defaults[key]) + else: + config[key] = defaults[key] + return config diff --git a/src/py/flwr/cli/flower_toml_test.py b/src/py/flwr/cli/flower_toml_test.py new file mode 100644 index 000000000000..72a52e4e8b9b --- /dev/null +++ b/src/py/flwr/cli/flower_toml_test.py @@ -0,0 +1,284 @@ +# Copyright 2024 Flower Labs 
GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface `run` command.""" + +import os +import textwrap +from typing import Any, Dict + +from .flower_toml import load, validate, validate_fields + + +def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None: + """Test if load_template returns a string.""" + # Prepare + flower_toml_content = """ + [project] + name = "fedgpt" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.engine] + name = "simulation" # optional + + [flower.engine.simulation.supernode] + count = 10 # optional + """ + expected_config = { + "project": { + "name": "fedgpt", + }, + "flower": { + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "engine": { + "name": "simulation", + "simulation": {"supernode": {"count": 10}}, + }, + }, + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open("flower.toml", "w", encoding="utf-8") as f: + f.write(textwrap.dedent(flower_toml_content)) + + # Execute + config = load() + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_load_flower_toml_from_path(tmp_path: str) -> None: + """Test if load_template returns a string.""" + # Prepare + flower_toml_content = """ 
+ [project] + name = "fedgpt" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.engine] + name = "simulation" # optional + + [flower.engine.simulation.supernode] + count = 10 # optional + """ + expected_config = { + "project": { + "name": "fedgpt", + }, + "flower": { + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "engine": { + "name": "simulation", + "simulation": {"supernode": {"count": 10}}, + }, + }, + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open("flower.toml", "w", encoding="utf-8") as f: + f.write(textwrap.dedent(flower_toml_content)) + + # Execute + config = load(path=os.path.join(tmp_path, "flower.toml")) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_validate_flower_toml_fields_empty() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config: Dict[str, Any] = {} + + # Execute + is_valid, errors, warnings = validate_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_flower() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + } + } + + # Execute + is_valid, errors, warnings = validate_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_flower_components() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {}, + } + + # Execute + is_valid, errors, warnings = 
validate_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields_no_server_and_client_app() -> None: + """Test that validate_flower_toml_fields fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {"components": {}}, + } + + # Execute + is_valid, errors, warnings = validate_fields(config) + + # Assert + assert not is_valid + assert len(errors) == 2 + assert len(warnings) == 0 + + +def test_validate_flower_toml_fields() -> None: + """Test that validate_flower_toml_fields succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": {"components": {"serverapp": "", "clientapp": ""}}, + } + + # Execute + is_valid, errors, warnings = validate_fields(config) + + # Assert + assert is_valid + assert len(errors) == 0 + assert len(warnings) == 0 + + +def test_validate_flower_toml() -> None: + """Test that validate_flower_toml succeeds correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": { + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:run", + } + }, + } + + # Execute + is_valid, errors, warnings = validate(config) + + # Assert + assert is_valid + assert not errors + assert not warnings + + +def test_validate_flower_toml_fail() -> None: + """Test that validate_flower_toml fails correctly.""" + # Prepare + config = { + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": "", + "authors": [], + }, + "flower": { + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:runa", + } + }, + } + + # Execute + is_valid, errors, warnings = validate(config) + 
+ # Assert + assert not is_valid + assert len(errors) == 1 + assert len(warnings) == 0 diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/cli/new/__init__.py similarity index 72% rename from src/py/flwr/flower/__init__.py rename to src/py/flwr/cli/new/__init__.py index 892a7ce5afdc..a973f47021c3 100644 --- a/src/py/flwr/flower/__init__.py +++ b/src/py/flwr/cli/new/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Adap GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower callable package.""" +"""Flower command line interface `new` command.""" - -from flwr.client.flower import Flower as Flower -from flwr.client.typing import Bwd as Bwd -from flwr.client.typing import Fwd as Fwd +from .new import new as new __all__ = [ - "Flower", - "Fwd", - "Bwd", + "new", ] diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py new file mode 100644 index 000000000000..7eb47e3e3548 --- /dev/null +++ b/src/py/flwr/cli/new/new.py @@ -0,0 +1,154 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower command line interface `new` command.""" + +import os +from enum import Enum +from string import Template +from typing import Dict, Optional + +import typer +from typing_extensions import Annotated + +from ..utils import prompt_options, prompt_text + + +class MlFramework(str, Enum): + """Available frameworks.""" + + NUMPY = "NumPy" + PYTORCH = "PyTorch" + TENSORFLOW = "TensorFlow" + + +class TemplateNotFound(Exception): + """Raised when template does not exist.""" + + +def load_template(name: str) -> str: + """Load template from template directory and return as text.""" + tpl_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "templates")) + tpl_file_path = os.path.join(tpl_dir, name) + + if not os.path.isfile(tpl_file_path): + raise TemplateNotFound(f"Template '{name}' not found") + + with open(tpl_file_path, encoding="utf-8") as tpl_file: + return tpl_file.read() + + +def render_template(template: str, data: Dict[str, str]) -> str: + """Render template.""" + tpl_file = load_template(template) + tpl = Template(tpl_file) + result = tpl.substitute(data) + return result + + +def create_file(file_path: str, content: str) -> None: + """Create file including all nessecary directories and write content into file.""" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + +def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> None: + """Render template and write to file.""" + content = render_template(template, context) + create_file(file_path, content) + + +def new( + project_name: Annotated[ + Optional[str], + typer.Argument(metavar="project_name", help="The name of the project"), + ] = None, + framework: Annotated[ + Optional[MlFramework], + typer.Option(case_sensitive=False, help="The ML framework to use"), + ] = None, +) -> None: + """Create new Flower project.""" 
+ print( + typer.style( + f"🔨 Creating Flower project {project_name}...", + fg=typer.colors.GREEN, + bold=True, + ) + ) + + if project_name is None: + project_name = prompt_text("Please provide project name") + + if framework is not None: + framework_str = str(framework.value) + else: + framework_value = prompt_options( + "Please select ML framework by typing in the number", + [mlf.value for mlf in MlFramework], + ) + selected_value = [ + name + for name, value in vars(MlFramework).items() + if value == framework_value + ] + framework_str = selected_value[0] + + framework_str = framework_str.lower() + + # Set project directory path + cwd = os.getcwd() + pnl = project_name.lower() + project_dir = os.path.join(cwd, pnl) + + # List of files to render + files = { + "README.md": {"template": "app/README.md.tpl"}, + "requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"}, + "flower.toml": {"template": "app/flower.toml.tpl"}, + "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, + f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"}, + f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"}, + f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"}, + } + + # In case framework is MlFramework.PYTORCH generate additionally the task.py file + if framework_str == MlFramework.PYTORCH.value.lower(): + files[f"{pnl}/task.py"] = {"template": f"app/code/task.{framework_str}.py.tpl"} + + context = {"project_name": project_name} + + for file_path, value in files.items(): + render_and_create( + file_path=os.path.join(project_dir, file_path), + template=value["template"], + context=context, + ) + + print( + typer.style( + "🎊 Project creation successful.\n\n" + "Use the following command to run your project:\n", + fg=typer.colors.GREEN, + bold=True, + ) + ) + print( + typer.style( + f" cd {project_name}\n" + " pip install -e .\n flwr run\n", + fg=typer.colors.BRIGHT_CYAN, + bold=True, + ) + ) 
diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py new file mode 100644 index 000000000000..cedcb09b7755 --- /dev/null +++ b/src/py/flwr/cli/new/new_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface `new` command.""" + +import os + +from .new import MlFramework, create_file, load_template, new, render_template + + +def test_load_template() -> None: + """Test if load_template returns a string.""" + # Prepare + filename = "app/README.md.tpl" + + # Execute + text = load_template(filename) + + # Assert + assert isinstance(text, str) + + +def test_render_template() -> None: + """Test if a string is correctly substituted.""" + # Prepare + filename = "app/README.md.tpl" + data = {"project_name": "FedGPT"} + + # Execute + result = render_template(filename, data) + + # Assert + assert "# FedGPT" in result + + +def test_create_file(tmp_path: str) -> None: + """Test if file with content is created.""" + # Prepare + file_path = os.path.join(tmp_path, "test.txt") + content = "Foobar" + + # Execute + create_file(file_path, content) + + # Assert + with open(file_path, encoding="utf-8") as f: + text = f.read() + + assert text == "Foobar" + + +def test_new(tmp_path: str) -> None: + """Test if project is created for framework.""" + # Prepare + project_name 
= "FedGPT" + framework = MlFramework.PYTORCH + expected_files_top_level = { + "requirements.txt", + "fedgpt", + "README.md", + "flower.toml", + "pyproject.toml", + } + expected_files_module = { + "__init__.py", + "server.py", + "client.py", + "task.py", + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temprorary directory + os.chdir(tmp_path) + + # Execute + new(project_name=project_name, framework=framework) + + # Assert + file_list = os.listdir(os.path.join(tmp_path, project_name.lower())) + assert set(file_list) == expected_files_top_level + + file_list = os.listdir( + os.path.join(tmp_path, project_name.lower(), project_name.lower()) + ) + assert set(file_list) == expected_files_module + finally: + os.chdir(origin) diff --git a/src/py/flwr/cli/new/templates/__init__.py b/src/py/flwr/cli/new/templates/__init__.py new file mode 100644 index 000000000000..7a951c2da1a2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower CLI `new` command templates.""" diff --git a/src/py/flwr/cli/new/templates/app/README.md.tpl b/src/py/flwr/cli/new/templates/app/README.md.tpl new file mode 100644 index 000000000000..516bed0f40c2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/README.md.tpl @@ -0,0 +1,47 @@ +# $project_name + +## Install dependencies + +```bash +# Using pip +pip install . + +# Or using Poetry +poetry install +``` + +## Run (Simulation Engine) + +In the `$project_name` directory, use `flwr run` to run a local simulation: + +```bash +flwr run +``` + +## Run (Deployment Engine) + +### Start the SuperLink + +```bash +flower-superlink --insecure +``` + +### Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +### Start the ServerApp + +```bash +flower-server-app server:app --insecure +``` diff --git a/src/py/flwr/cli/new/templates/app/__init__.py b/src/py/flwr/cli/new/templates/app/__init__.py new file mode 100644 index 000000000000..617628fc9138 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower CLI `new` command app templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py b/src/py/flwr/cli/new/templates/app/code/__init__.py new file mode 100644 index 000000000000..7f1a0e9f4fa2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower CLI `new` command app / code templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl new file mode 100644 index 000000000000..57998c81efb8 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl @@ -0,0 +1 @@ +"""$project_name.""" diff --git a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl new file mode 100644 index 000000000000..232c305fc2a9 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl @@ -0,0 +1,23 @@ +"""$project_name: A Flower / NumPy app.""" + +from flwr.client import NumPyClient, ClientApp +import numpy as np + + +class FlowerClient(NumPyClient): + def get_parameters(self, config): + return [np.ones((1, 1))] + + def fit(self, parameters, config): + return ([np.ones((1, 1))], 1, {}) + 
+ def evaluate(self, parameters, config): + return float(0.0), 1, {"accuracy": float(1.0)} + + +def client_fn(cid: str): + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp(client_fn=client_fn) diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl new file mode 100644 index 000000000000..7137a7791683 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -0,0 +1,46 @@ +"""$project_name: A Flower / PyTorch app.""" + +from flwr.client import NumPyClient, ClientApp + +from $project_name.task import ( + Net, + DEVICE, + load_data, + get_weights, + set_weights, + train, + test, +) + + +# Define Flower Client and client_fn +class FlowerClient(NumPyClient): + def __init__(self, net, trainloader, valloader): + self.net = net + self.trainloader = trainloader + self.valloader = valloader + + def fit(self, parameters, config): + set_weights(self.net, parameters) + results = train(self.net, self.trainloader, self.valloader, 1, DEVICE) + return get_weights(self.net), len(self.trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.valloader) + return loss, len(self.valloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid): + # Load model and data + net = Net().to(DEVICE) + trainloader, valloader = load_data(int(cid), 2) + + # Return Client instance + return FlowerClient(net, trainloader, valloader).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn, +) diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git 
a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl new file mode 100644 index 000000000000..03f95ae35cfd --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl @@ -0,0 +1,12 @@ +"""$project_name: A Flower / NumPy app.""" + +import flwr as fl + +# Configure the strategy +strategy = fl.server.strategy.FedAvg() + +# Flower ServerApp +app = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=1), + strategy=strategy, +) diff --git a/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl new file mode 100644 index 000000000000..cb04c052b429 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl @@ -0,0 +1,28 @@ +"""$project_name: A Flower / PyTorch app.""" + +from flwr.common import ndarrays_to_parameters +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg + +from $project_name.task import Net, get_weights + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0, + fraction_evaluate=1.0, + min_available_clients=2, + initial_parameters=parameters, +) + + +# Create ServerApp +app = ServerApp( + config=ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl new file mode 100644 index 000000000000..85460564b6ef --- /dev/null +++ 
b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -0,0 +1,106 @@ +"""$project_name: A Flower / PyTorch app.""" + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor +from flwr_datasets import FederatedDataset + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(partition_id, num_partitions): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) + partition = fds.load_partition(partition_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + pytorch_transforms = Compose( + [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch + + partition_train_test = partition_train_test.with_transform(apply_transforms) + trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) + testloader = DataLoader(partition_train_test["test"], batch_size=32) + return trainloader, testloader + + +def train(net, trainloader, valloader, epochs, 
device): + """Train the model on the training set.""" + net.to(device) # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net.train() + for _ in range(epochs): + for batch in trainloader: + images = batch["img"] + labels = batch["label"] + optimizer.zero_grad() + criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + "train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + """Validate the model on the test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for batch in testloader: + images = batch["img"].to(DEVICE) + labels = batch["label"].to(DEVICE) + outputs = net(images) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/src/py/flwr/cli/new/templates/app/flower.toml.tpl b/src/py/flwr/cli/new/templates/app/flower.toml.tpl new file mode 100644 index 000000000000..07a6ffaf9e49 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/flower.toml.tpl @@ -0,0 +1,13 @@ +[project] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[flower.components] +serverapp = "$project_name.server:app" +clientapp = "$project_name.client:app" 
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl new file mode 100644 index 000000000000..15d8211a1a25 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -0,0 +1,19 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.9" +# Mandatory dependencies +numpy = "^1.21.0" +flwr = { version = "^1.8.0", extras = ["simulation"] } diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl new file mode 100644 index 000000000000..da0e15b903f8 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -0,0 +1,21 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + "The Flower Authors ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.9" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240313", extras = ["simulation"] } +flwr-datasets = { version = "0.0.2", extras = ["vision"] } +torch = "2.2.1" +torchvision = "0.17.1" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl new file mode 100644 index 000000000000..f7383a78b7d5 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -0,0 +1,21 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = [ + 
"The Flower Authors ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<3.11" +# Mandatory dependencies +flwr = { version = "^1.8.0", extras = ["simulation"] } +flwr-datasets = { version = "^0.0.2", extras = ["vision"] } +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl new file mode 100644 index 000000000000..4b460798e96f --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl @@ -0,0 +1,2 @@ +flwr>=1.8, <2.0 +numpy>=1.21.0 diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl new file mode 100644 index 000000000000..ddb8a814447b --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl @@ -0,0 +1,4 @@ +flwr-nightly[simulation]==1.8.0.dev20240313 +flwr-datasets[vision]==0.0.2 +torch==2.2.1 +torchvision==0.17.1 diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl new file mode 100644 index 000000000000..b6fb49a4bbcb --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl @@ -0,0 +1,4 @@ +flwr>=1.8, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +tensorflow-macos>=2.9.1, !=2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" +tensorflow-cpu>=2.9.1, !=2.11.1 ; platform_machine == "x86_64" diff --git a/src/py/flwr/cli/run/__init__.py b/src/py/flwr/cli/run/__init__.py new file mode 100644 index 000000000000..43523c215d3e --- /dev/null +++ b/src/py/flwr/cli/run/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `run` command.""" + +from .run import run as run + +__all__ = [ + "run", +] diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py new file mode 100644 index 000000000000..98b5da1843a6 --- /dev/null +++ b/src/py/flwr/cli/run/run.py @@ -0,0 +1,68 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `run` command.""" + +import sys + +import typer + +from flwr.cli import flower_toml +from flwr.simulation.run_simulation import _run_simulation + + +def run() -> None: + """Run Flower project.""" + typer.secho("Loading project configuration... 
", fg=typer.colors.BLUE) + + config, errors, warnings = flower_toml.load_and_validate_with_defaults() + + if config is None: + typer.secho( + "Project configuration could not be loaded.\nflower.toml is invalid:\n" + + "\n".join([f"- {line}" for line in errors]), + fg=typer.colors.RED, + bold=True, + ) + sys.exit() + + if warnings: + typer.secho( + "Project configuration is missing the following " + "recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]), + fg=typer.colors.RED, + bold=True, + ) + + typer.secho("Success", fg=typer.colors.GREEN) + + server_app_ref = config["flower"]["components"]["serverapp"] + client_app_ref = config["flower"]["components"]["clientapp"] + engine = config["flower"]["engine"]["name"] + + if engine == "simulation": + num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"] + + typer.secho("Starting run... ", fg=typer.colors.BLUE) + _run_simulation( + server_app_attr=server_app_ref, + client_app_attr=client_app_ref, + num_supernodes=num_supernodes, + ) + else: + typer.secho( + f"Engine '{engine}' is not yet supported in `flwr run`", + fg=typer.colors.RED, + bold=True, + ) diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py new file mode 100644 index 000000000000..4e86f0c3b8c8 --- /dev/null +++ b/src/py/flwr/cli/utils.py @@ -0,0 +1,67 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower command line interface utils.""" + +from typing import List, cast + +import typer + + +def prompt_text(text: str) -> str: + """Ask user to enter text input.""" + while True: + result = typer.prompt( + typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True) + ) + if len(result) > 0: + break + print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True)) + + return cast(str, result) + + +def prompt_options(text: str, options: List[str]) -> str: + """Ask user to select one of the given options and return the selected item.""" + # Turn options into a list with index as in " [ 0] quickstart-pytorch" + options_formatted = [ + " [ " + + typer.style(index, fg=typer.colors.GREEN, bold=True) + + "]" + + f" {typer.style(name, fg=typer.colors.WHITE, bold=True)}" + for index, name in enumerate(options) + ] + + while True: + index = typer.prompt( + "\n" + + typer.style(f"💬 {text}", fg=typer.colors.MAGENTA, bold=True) + + "\n\n" + + "\n".join(options_formatted) + + "\n\n\n" + ) + try: + options[int(index)] # pylint: disable=expression-not-assigned + break + except IndexError: + print(typer.style("❌ Index out of range", fg=typer.colors.RED, bold=True)) + continue + except ValueError: + print( + typer.style("❌ Please choose a number", fg=typer.colors.RED, bold=True) + ) + continue + + result = options[int(index)] + return result diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index 13540a76cc25..a721fb584164 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -15,18 +15,20 @@ """Flower client.""" -from .app import run_client as run_client +from .app import run_client_app as run_client_app from .app import start_client as start_client from .app import start_numpy_client as start_numpy_client from .client import Client as Client +from .client_app import ClientApp as ClientApp from .numpy_client import NumPyClient as 
NumPyClient from .typing import ClientFn as ClientFn __all__ = [ "Client", + "ClientApp", "ClientFn", "NumPyClient", - "run_client", + "run_client_app", "start_client", "start_numpy_client", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 3448e18e20c5..c8287afc0fd0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,14 +18,16 @@ import argparse import sys import time -from logging import INFO, WARN +from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Callable, ContextManager, Optional, Tuple, Union +from typing import Callable, ContextManager, Optional, Tuple, Type, Union + +from grpc import RpcError from flwr.client.client import Client -from flwr.client.flower import Flower -from flwr.client.typing import Bwd, ClientFn, Fwd -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.typing import ClientFn +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -34,10 +36,11 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.common.exit_handlers import register_exit_handlers +from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.object_ref import load_app, validate +from flwr.common.retry_invoker import RetryInvoker, exponential -from .flower import load_flower_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message @@ -45,13 +48,13 @@ from .numpy_client import NumPyClient -def run_client() -> None: - """Run Flower client.""" - event(EventType.RUN_CLIENT_ENTER) +def 
run_client_app() -> None: + """Run Flower client app.""" + event(EventType.RUN_CLIENT_APP_ENTER) log(INFO, "Long-running Flower client starting") - args = _parse_args_client().parse_args() + args = _parse_args_run_client_app().parse_args() # Obtain certificates if args.insecure: @@ -62,7 +65,12 @@ def run_client() -> None: "the '--root-certificates' option when running in insecure mode, " "or omit '--insecure' to use HTTPS." ) - log(WARN, "Option `--insecure` was set. Starting insecure HTTP client.") + log( + WARN, + "Option `--insecure` was set. " + "Starting insecure HTTP client connected to %s.", + args.server, + ) root_certificates = None else: # Load the certificates if provided, or load the system certificates @@ -71,42 +79,72 @@ def run_client() -> None: root_certificates = None else: root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) - print(args.root_certificates) - print(args.server) - print(args.callable_dir) - print(args.callable) + log( + DEBUG, + "Flower will load ClientApp `%s`", + getattr(args, "client-app"), + ) + + client_app_dir = args.dir + if client_app_dir is not None: + sys.path.insert(0, client_app_dir) + + app_ref: str = getattr(args, "client-app") + valid, error_msg = validate(app_ref) + if not valid and error_msg: + raise LoadClientAppError(error_msg) from None - callable_dir = args.callable_dir - if callable_dir is not None: - sys.path.insert(0, callable_dir) + def _load() -> ClientApp: + client_app = load_app(app_ref, LoadClientAppError) - def _load() -> Flower: - flower: Flower = load_flower_callable(args.callable) - return flower + if not isinstance(client_app, ClientApp): + raise LoadClientAppError( + f"Attribute {app_ref} is not of type {ClientApp}", + ) from None + + return client_app _start_client_internal( server_address=args.server, - load_flower_callable_fn=_load, - transport="grpc-rere", # 
Only + load_client_app_fn=_load, + transport="rest" if args.rest else "grpc-rere", root_certificates=root_certificates, insecure=args.insecure, + max_retries=args.max_retries, + max_wait_time=args.max_wait_time, ) - event(EventType.RUN_CLIENT_LEAVE) + register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) -def _parse_args_client() -> argparse.ArgumentParser: - """Parse command line arguments.""" +def _parse_args_run_client_app() -> argparse.ArgumentParser: + """Parse flower-client-app command line arguments.""" parser = argparse.ArgumentParser( - description="Start a long-running Flower client", + description="Start a Flower client app", ) + parser.add_argument( + "client-app", + help="For example: `client:app` or `project.package.module:wrapper.app`", + ) parser.add_argument( "--insecure", action="store_true", help="Run the client without HTTPS. By default, the client runs with " "HTTPS enabled. Use this flag only if you understand the risks.", ) + parser.add_argument( + "--rest", + action="store_true", + help="Use REST as a transport layer for the client.", + ) parser.add_argument( "--root-certificates", metavar="ROOT_CERT", @@ -120,13 +158,26 @@ def _parse_args_client() -> argparse.ArgumentParser: help="Server address", ) parser.add_argument( - "--callable", - help="For example: `client:flower` or `project.package.module:wrapper.flower`", + "--max-retries", + type=int, + default=None, + help="The maximum number of times the client will try to connect to the" + "server before giving up in case of a connection error. By default," + "it is set to None, meaning there is no limit to the number of tries.", ) parser.add_argument( - "--callable-dir", + "--max-wait-time", + type=float, + default=None, + help="The maximum duration before the client stops trying to" + "connect to the server in case of connection error. 
By default, it" + "is set to None, meaning there is no limit to the total time.", + ) + parser.add_argument( + "--dir", default="", - help="Add specified directory to the PYTHONPATH and load callable from there." + help="Add specified directory to the PYTHONPATH and load Flower " + "app from there." " Default: current working directory.", ) @@ -137,10 +188,12 @@ def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: - raise Exception("Both `client_fn` and `client` are `None`, but one is required") + raise ValueError( + "Both `client_fn` and `client` are `None`, but one is required" + ) if client_fn is not None and client is not None: - raise Exception( + raise ValueError( "Both `client_fn` and `client` are provided, but only one is allowed" ) @@ -149,6 +202,7 @@ def _check_actionable_client( # pylint: disable=too-many-branches # pylint: disable=too-many-locals # pylint: disable=too-many-statements +# pylint: disable=too-many-arguments def start_client( *, server_address: str, @@ -158,6 +212,8 @@ def start_client( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -191,6 +247,14 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. 
+ If set to None, there is no limit to the total time. Examples -------- @@ -225,13 +289,15 @@ class `flwr.client.Client` (default: None) event(EventType.START_CLIENT_ENTER) _start_client_internal( server_address=server_address, - load_flower_callable_fn=None, + load_client_app_fn=None, client_fn=client_fn, client=client, grpc_max_message_length=grpc_max_message_length, root_certificates=root_certificates, insecure=insecure, transport=transport, + max_retries=max_retries, + max_wait_time=max_wait_time, ) event(EventType.START_CLIENT_LEAVE) @@ -243,13 +309,15 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, - load_flower_callable_fn: Optional[Callable[[], Flower]] = None, + load_client_app_fn: Optional[Callable[[], ClientApp]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -259,8 +327,8 @@ def _start_client_internal( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - load_flower_callable_fn : Optional[Callable[[], Flower]] (default: None) - A function that can be used to load a `Flower` callable instance. + load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) + A function that can be used to load a `ClientApp` instance. client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -277,7 +345,7 @@ class `flwr.client.Client` (default: None) The PEM-encoded root certificates as a byte string or a path string. 
If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. - insecure : bool (default: True) + insecure : Optional[bool] (default: None) Starts an insecure gRPC connection when True. Enables HTTPS connection when False, using system certificates if `root_certificates` is None. transport : Optional[str] (default: None) @@ -285,11 +353,19 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. + If set to None, there is no limit to the total time. 
""" if insecure is None: insecure = root_certificates is None - if load_flower_callable_fn is None: + if load_client_app_fn is None: _check_actionable_client(client, client_fn) if client_fn is None: @@ -298,25 +374,63 @@ def single_client_factory( cid: str, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy - raise Exception( + raise ValueError( "Both `client_fn` and `client` are `None`, but one is required" ) return client # Always return the same instance client_fn = single_client_factory - def _load_app() -> Flower: - return Flower(client_fn=client_fn) + def _load_client_app() -> ClientApp: + return ClientApp(client_fn=client_fn) - load_flower_callable_fn = _load_app + load_client_app_fn = _load_client_app else: - warn_experimental_feature("`load_flower_callable_fn`") + warn_experimental_feature("`load_client_app_fn`") - # At this point, only `load_flower_callable_fn` should be used + # At this point, only `load_client_app_fn` should be used # Both `client` and `client_fn` must not be used directly # Initialize connection context manager - connection, address = _init_connection(transport, server_address) + connection, address, connection_error_type = _init_connection( + transport, server_address + ) + + retry_invoker = RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=connection_error_type, + max_tries=max_retries, + max_time=max_wait_time, + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + "Connection 
attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), + ) node_state = NodeState() @@ -325,6 +439,7 @@ def _load_app() -> Flower: with connection( address, insecure, + retry_invoker, grpc_max_message_length, root_certificates, ) as conn: @@ -336,40 +451,63 @@ def _load_app() -> Flower: while True: # Receive - task_ins = receive() - if task_ins is None: + message = receive() + if message is None: time.sleep(3) # Wait for 3s before asking again continue + log(INFO, "") + log( + INFO, + "[RUN %s, ROUND %s]", + message.metadata.run_id, + message.metadata.group_id, + ) + log( + INFO, + "Received: %s message %s", + message.metadata.message_type, + message.metadata.message_id, + ) + # Handle control message - task_res, sleep_duration = handle_control_message(task_ins=task_ins) - if task_res: - send(task_res) + out_message, sleep_duration = handle_control_message(message) + if out_message: + send(out_message) break - # Register state - node_state.register_workloadstate(workload_id=task_ins.workload_id) + # Register context for this run + node_state.register_context(run_id=message.metadata.run_id) + + # Retrieve context for this run + context = node_state.retrieve_context(run_id=message.metadata.run_id) - # Load app - app: Flower = load_flower_callable_fn() + # Load ClientApp instance + client_app: ClientApp = load_client_app_fn() # Handle task message - fwd_msg: Fwd = Fwd( - task_ins=task_ins, - state=node_state.retrieve_workloadstate( - workload_id=task_ins.workload_id - ), - ) - bwd_msg: Bwd = app(fwd=fwd_msg) + out_message = client_app(message=message, context=context) # Update node state - node_state.update_workloadstate( - workload_id=bwd_msg.task_res.workload_id, - workload_state=bwd_msg.state, + node_state.update_context( + run_id=message.metadata.run_id, + context=context, ) # Send - send(bwd_msg.task_res) + send(out_message) + log( + INFO, + "[RUN %s, ROUND %s]", + out_message.metadata.run_id, + out_message.metadata.group_id, + ) + log( + 
INFO, + "Sent: %s reply to message %s", + out_message.metadata.message_type, + message.metadata.message_id, + ) # Unregister node if delete_node is not None: @@ -398,6 +536,12 @@ def start_numpy_client( ) -> None: """Start a Flower NumPyClient which connects to a gRPC server. + Warning + ------- + This function is deprecated since 1.7.0. Use :code:`flwr.client.start_client` + instead and first convert your :code:`NumPyClient` to type + :code:`flwr.client.Client` by executing its :code:`to_client()` method. + Parameters ---------- server_address : str @@ -453,21 +597,22 @@ def start_numpy_client( >>> root_certificates=Path("/crts/root.pem").read_bytes(), >>> ) """ - # warnings.warn( - # "flwr.client.start_numpy_client() is deprecated and will " - # "be removed in a future version of Flower. Instead, pass " - # "your client to `flwr.client.start_client()` by calling " - # "first the `.to_client()` method as shown below: \n" - # "\tflwr.client.start_client(\n" - # "\t\tserver_address=':',\n" - # "\t\tclient=FlowerClient().to_client()\n" - # "\t)", - # DeprecationWarning, - # stacklevel=2, - # ) + mssg = ( + "flwr.client.start_numpy_client() is deprecated. \n\tInstead, use " + "`flwr.client.start_client()` by ensuring you first call " + "the `.to_client()` method as shown below: \n" + "\tflwr.client.start_client(\n" + "\t\tserver_address=':',\n" + "\t\tclient=FlowerClient().to_client()," + " # <-- where FlowerClient is of type flwr.client.NumPyClient object\n" + "\t)\n" + "\tUsing `start_numpy_client()` is deprecated." + ) + + warn_deprecated_feature(name=mssg) # Calling this function is deprecated. A warning is thrown. 
- # We first need to convert either the supplied client to `Client.` + # We first need to convert the supplied client to `Client.` wrp_client = client.to_client() @@ -481,21 +626,20 @@ def start_numpy_client( ) -def _init_connection( - transport: Optional[str], server_address: str -) -> Tuple[ +def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[ - [str, bool, int, Union[bytes, str, None]], + [str, bool, RetryInvoker, int, Union[bytes, str, None]], ContextManager[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] ], ], str, + Type[Exception], ]: # Parse IP address parsed_address = parse_address(server_address) @@ -511,6 +655,8 @@ def _init_connection( # Use either gRPC bidirectional streaming or REST request/response if transport == TRANSPORT_TYPE_REST: try: + from requests.exceptions import ConnectionError as RequestsConnectionError + from .rest_client.connection import http_request_response except ModuleNotFoundError: sys.exit(MISSING_EXTRA_REST) @@ -519,14 +665,14 @@ def _init_connection( "When using the REST API, please provide `https://` or " "`http://` before the server address (e.g. 
`http://127.0.0.1:8080`)" ) - connection = http_request_response + connection, error_type = http_request_response, RequestsConnectionError elif transport == TRANSPORT_TYPE_GRPC_RERE: - connection = grpc_request_response + connection, error_type = grpc_request_response, RpcError elif transport == TRANSPORT_TYPE_GRPC_BIDI: - connection = grpc_connection + connection, error_type = grpc_connection, RpcError else: raise ValueError( f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})" ) - return connection, address + return connection, address, error_type diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 7ef6410debad..56d6308a0fe2 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -41,19 +41,19 @@ class PlainClient(Client): def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit(self, ins: FitIns) -> FitRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate(self, ins: EvaluateIns) -> EvaluateRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() class NeedsWrappingClient(NumPyClient): @@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient): def get_properties(self, config: Config) -> Dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, config: Config) -> NDArrays: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise 
NotImplementedError() def fit( self, parameters: NDArrays, config: Config ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config ) -> Tuple[float, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def test_to_client_with_client() -> None: diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 280e0a8ca989..23a3755f3efe 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,9 +19,9 @@ from abc import ABC -from flwr.client.workload_state import WorkloadState from flwr.common import ( Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -38,7 +38,7 @@ class Client(ABC): """Abstract base class for Flower clients.""" - state: WorkloadState + context: Context def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Return client (itself).""" diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py new file mode 100644 index 000000000000..ad7a01326991 --- /dev/null +++ b/src/py/flwr/client/client_app.py @@ -0,0 +1,224 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ClientApp.""" + + +from typing import Callable, List, Optional + +from flwr.client.message_handler.message_handler import ( + handle_legacy_message_from_msgtype, +) +from flwr.client.mod.utils import make_ffn +from flwr.client.typing import ClientFn, Mod +from flwr.common import Context, Message, MessageType + +from .typing import ClientAppCallable + + +class ClientApp: + """Flower ClientApp. + + Examples + -------- + Assuming a typical `Client` implementation named `FlowerClient`, you can wrap it in + a `ClientApp` as follows: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid): + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp(client_fn) + + If the above code is in a Python module called `client`, it can be started as + follows: + + >>> flower-client-app client:app --insecure + + In this `client:app` example, `client` refers to the Python module `client.py` in + which the previous code lives in and `app` refers to the global attribute `app` that + points to an object of type `ClientApp`. 
+ """ + + def __init__( + self, + client_fn: Optional[ClientFn] = None, # Only for backward compatibility + mods: Optional[List[Mod]] = None, + ) -> None: + self._mods: List[Mod] = mods if mods is not None else [] + + # Create wrapper function for `handle` + self._call: Optional[ClientAppCallable] = None + if client_fn is not None: + + def ffn( + message: Message, + context: Context, + ) -> Message: # pylint: disable=invalid-name + out_message = handle_legacy_message_from_msgtype( + client_fn=client_fn, message=message, context=context + ) + return out_message + + # Wrap mods around the wrapped handle function + self._call = make_ffn(ffn, mods if mods is not None else []) + + # Step functions + self._train: Optional[ClientAppCallable] = None + self._evaluate: Optional[ClientAppCallable] = None + self._query: Optional[ClientAppCallable] = None + + def __call__(self, message: Message, context: Context) -> Message: + """Execute `ClientApp`.""" + # Execute message using `client_fn` + if self._call: + return self._call(message, context) + + # Execute message using a new + if message.metadata.message_type == MessageType.TRAIN: + if self._train: + return self._train(message, context) + raise ValueError("No `train` function registered") + if message.metadata.message_type == MessageType.EVALUATE: + if self._evaluate: + return self._evaluate(message, context) + raise ValueError("No `evaluate` function registered") + if message.metadata.message_type == MessageType.QUERY: + if self._query: + return self._query(message, context) + raise ValueError("No `query` function registered") + + # Message type did not match one of the known message types abvoe + raise ValueError(f"Unknown message_type: {message.metadata.message_type}") + + def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the train fn with the client app. 
+ + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.train() + >>> def train(message: Message, context: Context) -> Message: + >>> print("ClientApp training running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: + """Register the train fn with the ServerApp object.""" + if self._call: + raise _registration_error(MessageType.TRAIN) + + # Register provided function with the ClientApp object + # Wrap mods around the wrapped step function + self._train = make_ffn(train_fn, self._mods) + + # Return provided function unmodified + return train_fn + + return train_decorator + + def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the evaluate fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.evaluate() + >>> def evaluate(message: Message, context: Context) -> Message: + >>> print("ClientApp evaluation running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: + """Register the evaluate fn with the ServerApp object.""" + if self._call: + raise _registration_error(MessageType.EVALUATE) + + # Register provided function with the ClientApp object + # Wrap mods around the wrapped step function + self._evaluate = make_ffn(evaluate_fn, self._mods) + + # Return provided function unmodified + return evaluate_fn + + return evaluate_decorator + + def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the query fn with the client app. 
+ + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.query() + >>> def query(message: Message, context: Context) -> Message: + >>> print("ClientApp query running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: + """Register the query fn with the ServerApp object.""" + if self._call: + raise _registration_error(MessageType.QUERY) + + # Register provided function with the ClientApp object + # Wrap mods around the wrapped step function + self._query = make_ffn(query_fn, self._mods) + + # Return provided function unmodified + return query_fn + + return query_decorator + + +class LoadClientAppError(Exception): + """Error when trying to load `ClientApp`.""" + + +def _registration_error(fn_name: str) -> ValueError: + return ValueError( + f"""Use either `@app.{fn_name}()` or `client_fn`, but not both. + + Use the `ClientApp` with an existing `client_fn`: + + >>> class FlowerClient(NumPyClient): + >>> # ... 
+ >>> + >>> def client_fn(cid) -> Client: + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp() + >>> client_fn=client_fn, + >>> ) + + Use the `ClientApp` with a custom {fn_name} function: + + >>> app = ClientApp() + >>> + >>> @app.{fn_name}() + >>> def {fn_name}(message: Message, context: Context) -> Message: + >>> print("ClientApp {fn_name} running") + >>> # Create and return an echo reply message + >>> return message.create_reply( + >>> content=message.content(), ttl="" + >>> ) + """, + ) diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index 41b4d676df43..ab31a289d29b 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -22,13 +22,20 @@ from flwr.client.numpy_client import NumPyClient from flwr.common.dp import add_gaussian_noise, clip_by_l2 +from flwr.common.logger import warn_deprecated_feature from flwr.common.typing import Config, NDArrays, Scalar class DPFedAvgNumPyClient(NumPyClient): - """Wrapper for configuring a Flower client for DP.""" + """Wrapper for configuring a Flower client for DP. + + Warning + ------- + This class is deprecated and will be removed in a future release. 
+ """ def __init__(self, client: NumPyClient) -> None: + warn_deprecated_feature("`DPFedAvgNumPyClient` wrapper") super().__init__() self.client = client @@ -117,16 +124,16 @@ def fit( update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)] if "dpfedavg_clip_norm" not in config: - raise Exception("Clipping threshold not supplied by the server.") + raise KeyError("Clipping threshold not supplied by the server.") if not isinstance(config["dpfedavg_clip_norm"], float): - raise Exception("Clipping threshold should be a floating point value.") + raise TypeError("Clipping threshold should be a floating point value.") # Clipping update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"]) if "dpfedavg_noise_stddev" in config: if not isinstance(config["dpfedavg_noise_stddev"], float): - raise Exception( + raise TypeError( "Scale of noise to be added should be a floating point value." ) # Noising @@ -138,7 +145,7 @@ def fit( # Calculating value of norm indicator bit, required for adaptive clipping if "dpfedavg_adaptive_clip_enabled" in config: if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool): - raise Exception( + raise TypeError( "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag." ) metrics["dpfedavg_norm_bit"] = not clipped diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py deleted file mode 100644 index 10c78ec45b44..000000000000 --- a/src/py/flwr/client/flower.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower callable.""" - - -import importlib -from typing import cast - -from flwr.client.message_handler.message_handler import handle -from flwr.client.typing import Bwd, ClientFn, Fwd - - -class Flower: - """Flower callable. - - Examples - -------- - Assuming a typical client implementation in `FlowerClient`, you can wrap it in a - Flower callable as follows: - - >>> class FlowerClient(NumPyClient): - >>> # ... - >>> - >>> def client_fn(cid): - >>> return FlowerClient().to_client() - >>> - >>> flower = Flower(client_fn) - - If the above code is in a Python module called `client`, it can be started as - follows: - - >>> flower-client --callable client:flower - - In this `client:flower` example, `client` refers to the Python module in which the - previous code lives in. `flower` refers to the global attribute `flower` that points - to an object of type `Flower` (a Flower callable). - """ - - def __init__( - self, - client_fn: ClientFn, # Only for backward compatibility - ) -> None: - self.client_fn = client_fn - - def __call__(self, fwd: Fwd) -> Bwd: - """.""" - # Execute the task - task_res, state_updated = handle( - client_fn=self.client_fn, - state=fwd.state, - task_ins=fwd.task_ins, - ) - return Bwd( - task_res=task_res, - state=state_updated, - ) - - -class LoadCallableError(Exception): - """.""" - - -def load_flower_callable(module_attribute_str: str) -> Flower: - """Load the `Flower` object specified in a module attribute string. 
- - The module/attribute string should have the form :. Valid - examples include `client:flower` and `project.package.module:wrapper.flower`. It - must refer to a module on the PYTHONPATH, the module needs to have the specified - attribute, and the attribute must be of type `Flower`. - """ - module_str, _, attributes_str = module_attribute_str.partition(":") - if not module_str: - raise LoadCallableError( - f"Missing module in {module_attribute_str}", - ) from None - if not attributes_str: - raise LoadCallableError( - f"Missing attribute in {module_attribute_str}", - ) from None - - # Load module - try: - module = importlib.import_module(module_str) - except ModuleNotFoundError: - raise LoadCallableError( - f"Unable to load module {module_str}", - ) from None - - # Recursively load attribute - attribute = module - try: - for attribute_str in attributes_str.split("."): - attribute = getattr(attribute, attribute_str) - except AttributeError: - raise LoadCallableError( - f"Unable to load attribute {attributes_str} from module {module_str}", - ) from None - - # Check type - if not isinstance(attribute, Flower): - raise LoadCallableError( - f"Attribute {attributes_str} is not of type {Flower}", - ) from None - - return cast(Flower, attribute) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 335d28e72828..163a58542c9e 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,15 +20,27 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union - -from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from typing import Callable, Iterator, Optional, Tuple, Union, cast + +from flwr.common import ( + GRPC_MAX_MESSAGE_LENGTH, + ConfigsRecord, + Message, + Metadata, + RecordSet, +) +from flwr.common import recordset_compat as compat +from flwr.common import serde +from 
flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.proto.transport_pb2_grpc import FlowerServiceStub +from flwr.common.retry_invoker import RetryInvoker +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Reason, + ServerMessage, +) +from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611 # The following flags can be uncommented for debugging. Other possible values: # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md @@ -43,15 +55,16 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_connection( +def grpc_connection( # pylint: disable=R0915 server_address: str, insecure: bool, + retry_invoker: RetryInvoker, # pylint: disable=unused-argument max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] @@ -64,6 +77,11 @@ def grpc_connection( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or `"[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + Unused argument present for compatibilty. max_message_length : int The maximum length of gRPC messages that can be exchanged with the Flower server. The default should be sufficient for most models. 
Users who train @@ -114,23 +132,94 @@ def grpc_connection( server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None)) - def receive() -> TaskIns: - server_message = next(server_message_iterator) - return TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - workload_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=server_message, + def receive() -> Message: + # Receive ServerMessage proto + proto = next(server_message_iterator) + + # ServerMessage proto --> *Ins --> RecordSet + field = proto.WhichOneof("msg") + message_type = "" + if field == "get_properties_ins": + recordset = compat.getpropertiesins_to_recordset( + serde.get_properties_ins_from_proto(proto.get_properties_ins) + ) + message_type = MessageTypeLegacy.GET_PROPERTIES + elif field == "get_parameters_ins": + recordset = compat.getparametersins_to_recordset( + serde.get_parameters_ins_from_proto(proto.get_parameters_ins) + ) + message_type = MessageTypeLegacy.GET_PARAMETERS + elif field == "fit_ins": + recordset = compat.fitins_to_recordset( + serde.fit_ins_from_proto(proto.fit_ins), False + ) + message_type = MessageType.TRAIN + elif field == "evaluate_ins": + recordset = compat.evaluateins_to_recordset( + serde.evaluate_ins_from_proto(proto.evaluate_ins), False + ) + message_type = MessageType.EVALUATE + elif field == "reconnect_ins": + recordset = RecordSet() + recordset.configs_records["config"] = ConfigsRecord( + {"seconds": proto.reconnect_ins.seconds} + ) + message_type = "reconnect" + else: + raise ValueError( + "Unsupported instruction in ServerMessage, " + "cannot deserialize from ProtoBuf" + ) + + # Construct Message + return Message( + metadata=Metadata( + run_id=0, + message_id=str(uuid.uuid4()), + src_node_id=0, + dst_node_id=0, + reply_to_message="", + group_id="", + ttl="", + message_type=message_type, ), + content=recordset, ) - def send(task_res: TaskRes) -> None: - 
msg = task_res.task.legacy_client_message - return queue.put(msg, block=False) + def send(message: Message) -> None: + # Retrieve RecordSet and message_type + recordset = message.content + message_type = message.metadata.message_type + + # RecordSet --> *Res --> *Res proto -> ClientMessage proto + if message_type == MessageTypeLegacy.GET_PROPERTIES: + getpropres = compat.recordset_to_getpropertiesres(recordset) + msg_proto = ClientMessage( + get_properties_res=serde.get_properties_res_to_proto(getpropres) + ) + elif message_type == MessageTypeLegacy.GET_PARAMETERS: + getparamres = compat.recordset_to_getparametersres(recordset, False) + msg_proto = ClientMessage( + get_parameters_res=serde.get_parameters_res_to_proto(getparamres) + ) + elif message_type == MessageType.TRAIN: + fitres = compat.recordset_to_fitres(recordset, False) + msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres)) + elif message_type == MessageType.EVALUATE: + evalres = compat.recordset_to_evaluateres(recordset) + msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres)) + elif message_type == "reconnect": + reason = cast( + Reason.ValueType, recordset.configs_records["config"]["reason"] + ) + msg_proto = ClientMessage( + disconnect_res=ClientMessage.DisconnectRes(reason=reason) + ) + else: + raise ValueError(f"Invalid message type: {message_type}") + + # Send ClientMessage proto + return queue.put(msg_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index e5944230e5af..b7737f511a2a 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,20 +23,53 @@ import grpc -from flwr.proto.task_pb2 import Task, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.common import ConfigsRecord, Message, Metadata, RecordSet +from flwr.common import recordset_compat as 
compat +from flwr.common.constant import MessageTypeLegacy +from flwr.common.retry_invoker import RetryInvoker, exponential +from flwr.common.typing import Code, GetPropertiesRes, Status +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_server import start_grpc_server +from flwr.server.superlink.fleet.grpc_bidi.grpc_server import start_grpc_server from .connection import grpc_connection EXPECTED_NUM_SERVER_MESSAGE = 10 -SERVER_MESSAGE = ServerMessage() +SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns()) SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns()) -CLIENT_MESSAGE = ClientMessage() -CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes()) +MESSAGE_GET_PROPERTIES = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + group_id="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + ), + content=compat.getpropertiesres_to_recordset( + GetPropertiesRes(Status(Code.OK, ""), {}) + ), +) +MESSAGE_DISCONNECT = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + group_id="", + ttl="", + message_type="reconnect", + ), + content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), +) def unused_tcp_port() -> int: @@ -72,7 +105,9 @@ def mock_join( # type: ignore # pylint: disable=invalid-name @patch( - "flwr.server.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join", + # pylint: disable=line-too-long + "flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join", # noqa: E501 + # pylint: enable=line-too-long mock_join, ) def test_integration_connection() -> None: @@ -93,40 +128,30 @@ def test_integration_connection() -> None: def 
run_client() -> int: messages_received: int = 0 - with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn: + with grpc_connection( + server_address=f"[::]:{port}", + insecure=True, + retry_invoker=RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=1, + max_time=None, + ), + ) as conn: receive, send, _, _ = conn # Setup processing loop while True: # Block until server responds with a message - task_ins = receive() - - if task_ins is None: - raise ValueError("Unexpected None value") - - # pylint: disable=no-member - if task_ins.HasField("task") and task_ins.task.HasField( - "legacy_server_message" - ): - server_message = task_ins.task.legacy_server_message - else: - server_message = None - # pylint: enable=no-member - - if server_message is None: - raise ValueError("Unexpected None value") + message = receive() messages_received += 1 - if server_message.HasField("reconnect_ins"): - task_res = TaskRes( - task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT) - ) - send(task_res) + if message.metadata.message_type == "reconnect": # type: ignore + send(MESSAGE_DISCONNECT) break # Process server_message and send client_message... 
- task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE)) - send(task_res) + send(MESSAGE_GET_PROPERTIES) return messages_received diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 30d407a52c53..e6e22998b947 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -16,31 +16,31 @@ from contextlib import contextmanager +from copy import copy from logging import DEBUG, ERROR from pathlib import Path from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast -from flwr.client.message_handler.task_handler import ( - configure_task_res, - get_task_ins, - validate_task_ins, - validate_task_res, -) +from flwr.client.message_handler.message_handler import validate_out_message +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.fleet_pb2 import ( +from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker +from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from flwr.proto.fleet_pb2_grpc import FleetStub -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 KEY_NODE = "node" -KEY_TASK_INS = "current_task_ins" +KEY_METADATA = "in_message_metadata" def on_channel_state_change(channel_connectivity: str) -> None: @@ -52,12 +52,13 @@ def on_channel_state_change(channel_connectivity: str) 
-> None: def grpc_request_response( server_address: str, insecure: bool, + retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] @@ -73,6 +74,13 @@ def grpc_request_response( The IPv6 address of the server with `http://` or `https://`. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after gRPC errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. 
root_certificates : Optional[Union[bytes, str]] (default: None) @@ -101,8 +109,8 @@ def grpc_request_response( channel.subscribe(on_channel_state_change) stub = FleetStub(channel) - # Necessary state to link TaskRes to TaskIns - state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} + # Necessary state to validate messages to be sent + state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} # Enable create_node and delete_node to store node node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} @@ -114,7 +122,8 @@ def grpc_request_response( def create_node() -> None: """Set create_node.""" create_node_request = CreateNodeRequest() - create_node_response = stub.CreateNode( + create_node_response = retry_invoker.invoke( + stub.CreateNode, request=create_node_request, ) node_store[KEY_NODE] = create_node_response.node @@ -128,11 +137,11 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_request = DeleteNodeRequest(node=node) - stub.DeleteNode(request=delete_node_request) + retry_invoker.invoke(stub.DeleteNode, request=delete_node_request) del node_store[KEY_NODE] - def receive() -> Optional[TaskIns]: + def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node if node_store[KEY_NODE] is None: @@ -142,50 +151,53 @@ def receive() -> Optional[TaskIns]: # Request instructions (task) from server request = PullTaskInsRequest(node=node) - response = stub.PullTaskIns(request=request) + response = retry_invoker.invoke(stub.PullTaskIns, request=request) # Get the current TaskIns task_ins: Optional[TaskIns] = get_task_ins(response) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True + if task_ins is not None and not ( + task_ins.task.consumer.node_id == node.node_id + and validate_task_ins(task_ins) ): task_ins = None - # Remember `task_ins` until `task_res` is available - state[KEY_TASK_INS] = task_ins + # Construct the 
Message + in_message = message_from_taskins(task_ins) if task_ins else None + + # Remember `metadata` of the in message + state[KEY_METADATA] = copy(in_message.metadata) if in_message else None - # Return the TaskIns if available - return task_ins + # Return the message if available + return in_message - def send(task_res: TaskRes) -> None: + def send(message: Message) -> None: """Send task result back to server.""" # Get Node if node_store[KEY_NODE] is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) - # Get incoming TaskIns - if state[KEY_TASK_INS] is None: - log(ERROR, "No current TaskIns") + # Get incoming message + in_metadata = state[KEY_METADATA] + if in_metadata is None: + log(ERROR, "No current message") return - task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS]) - # Check if fields to be set are not initialized - if not validate_task_res(task_res): - state[KEY_TASK_INS] = None - log(ERROR, "TaskRes has been initialized accidentally") + # Validate out message + if not validate_out_message(message, in_metadata): + log(ERROR, "Invalid out message") + return - # Configure TaskRes - task_res = configure_task_res(task_res, task_ins, node) + # Construct TaskRes + task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes request = PushTaskResRequest(task_res_list=[task_res]) - _ = stub.PushTaskRes(request) + _ = retry_invoker.invoke(stub.PushTaskRes, request) - state[KEY_TASK_INS] = None + state[KEY_METADATA] = None try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 0f3070cfb01a..9a5d70b1ac4d 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,25 +15,34 @@ """Client-side message handler.""" -from typing import Optional, Tuple +from logging import WARN +from typing import Optional, Tuple, cast from flwr.client.client import 
( - Client, maybe_call_evaluate, maybe_call_fit, maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - wrap_client_message_in_task_res, -) -from flwr.client.secure_aggregation import SecureAggregationHandler +from flwr.client.numpy_client import NumPyClient from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState -from flwr.common import serde -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log +from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.recordset_compat import ( + evaluateres_to_recordset, + fitres_to_recordset, + getparametersres_to_recordset, + getpropertiesres_to_recordset, + recordset_to_evaluateins, + recordset_to_fitins, + recordset_to_getparametersins, + recordset_to_getpropertiesins, +) +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Reason, + ServerMessage, +) class UnexpectedServerMessage(Exception): @@ -44,128 +53,97 @@ class UnknownServerMessage(Exception): """Exception indicating that the received message is unknown.""" -def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: +def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: """Handle control part of the incoming message. Parameters ---------- - task_ins : TaskIns - The task instruction coming from the server, to be processed by the client. + message : Message + The Message coming from the server, to be processed by the client. Returns ------- + message : Optional[Message] + Message to be sent back to the server. If None, the client should + continue to process messages from the server. 
sleep_duration : int Number of seconds that the client should disconnect from the server. - keep_going : bool - Flag that indicates whether the client should continue to process the - next message from the server (True) or disconnect and optionally - reconnect later (False). """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - # SecAgg message - if server_msg is None: - return None, 0 - - # ReconnectIns message - field = server_msg.WhichOneof("msg") - if field == "reconnect_ins": - disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins) - task_res = wrap_client_message_in_task_res(disconnect_msg) - return task_res, sleep_duration + if message.metadata.message_type == "reconnect": + # Retrieve ReconnectIns from recordset + recordset = message.content + seconds = cast(int, recordset.configs_records["config"]["seconds"]) + # Construct ReconnectIns and call _reconnect + disconnect_msg, sleep_duration = _reconnect( + ServerMessage.ReconnectIns(seconds=seconds) + ) + # Store DisconnectRes in recordset + reason = cast(int, disconnect_msg.disconnect_res.reason) + recordset = RecordSet() + recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) + out_message = message.create_reply(recordset, ttl="") + # Return TaskRes and sleep duration + return out_message, sleep_duration # Any other message return None, 0 -def handle( - client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns -) -> Tuple[TaskRes, WorkloadState]: - """Handle incoming TaskIns from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. - task_ins: TaskIns - The task instruction coming from the server, to be processed by the client. - - Returns - ------- - task_res : TaskRes - The task response that should be returned to the server. 
- """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - if server_msg is None: - # Instantiate the client - client = client_fn("-1") - client.set_state(state) - # Secure Aggregation - if task_ins.task.HasField("sa") and isinstance( - client, SecureAggregationHandler - ): - # pylint: disable-next=invalid-name - named_values = serde.named_values_from_proto(task_ins.task.sa.named_values) - res = client.handle_secure_aggregation(named_values) - task_res = TaskRes( - task_id="", - group_id="", - workload_id=0, - task=Task( - ancestry=[], - sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), - ), - ) - return task_res, client.get_state() - raise NotImplementedError() - client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) - task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_state - - -def handle_legacy_message( - client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage -) -> Tuple[ClientMessage, WorkloadState]: - """Handle incoming messages from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. - server_msg: ServerMessage - The message coming from the server, to be processed by the client. - - Returns - ------- - client_msg : ClientMessage - The result message that should be returned to the server. 
- """ - field = server_msg.WhichOneof("msg") - - # Must be handled elsewhere - if field == "reconnect_ins": - raise UnexpectedServerMessage() - - # Instantiate the client - client = client_fn("-1") - client.set_state(state) - # Execute task - message = None - if field == "get_properties_ins": - message = _get_properties(client, server_msg.get_properties_ins) - if field == "get_parameters_ins": - message = _get_parameters(client, server_msg.get_parameters_ins) - if field == "fit_ins": - message = _fit(client, server_msg.fit_ins) - if field == "evaluate_ins": - message = _evaluate(client, server_msg.evaluate_ins) - if message: - return message, client.get_state() - raise UnknownServerMessage() +def handle_legacy_message_from_msgtype( + client_fn: ClientFn, message: Message, context: Context +) -> Message: + """Handle legacy message in the inner most mod.""" + client = client_fn(str(message.metadata.partition_id)) + + # Check if NumPyClient is returend + if isinstance(client, NumPyClient): + client = client.to_client() + log( + WARN, + "Deprecation Warning: The `client_fn` function must return an instance " + "of `Client`, but an instance of `NumpyClient` was returned. 
" + "Please use `NumPyClient.to_client()` method to convert it to `Client`.", + ) + + client.set_context(context) + + message_type = message.metadata.message_type + + # Handle GetPropertiesIns + if message_type == MessageTypeLegacy.GET_PROPERTIES: + get_properties_res = maybe_call_get_properties( + client=client, + get_properties_ins=recordset_to_getpropertiesins(message.content), + ) + out_recordset = getpropertiesres_to_recordset(get_properties_res) + # Handle GetParametersIns + elif message_type == MessageTypeLegacy.GET_PARAMETERS: + get_parameters_res = maybe_call_get_parameters( + client=client, + get_parameters_ins=recordset_to_getparametersins(message.content), + ) + out_recordset = getparametersres_to_recordset( + get_parameters_res, keep_input=False + ) + # Handle FitIns + elif message_type == MessageType.TRAIN: + fit_res = maybe_call_fit( + client=client, + fit_ins=recordset_to_fitins(message.content, keep_input=True), + ) + out_recordset = fitres_to_recordset(fit_res, keep_input=False) + # Handle EvaluateIns + elif message_type == MessageType.EVALUATE: + evaluate_res = maybe_call_evaluate( + client=client, + evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True), + ) + out_recordset = evaluateres_to_recordset(evaluate_res) + else: + raise ValueError(f"Invalid message type: {message_type}") + + # Return Message + return message.create_reply(out_recordset, ttl="") def _reconnect( @@ -182,65 +160,18 @@ def _reconnect( return ClientMessage(disconnect_res=disconnect_res), sleep_duration -def _get_properties( - client: Client, get_properties_msg: ServerMessage.GetPropertiesIns -) -> ClientMessage: - # Deserialize `get_properties` instruction - get_properties_ins = serde.get_properties_ins_from_proto(get_properties_msg) - - # Request properties - get_properties_res = maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - # Serialize response - get_properties_res_proto = 
serde.get_properties_res_to_proto(get_properties_res) - return ClientMessage(get_properties_res=get_properties_res_proto) - - -def _get_parameters( - client: Client, get_parameters_msg: ServerMessage.GetParametersIns -) -> ClientMessage: - # Deserialize `get_parameters` instruction - get_parameters_ins = serde.get_parameters_ins_from_proto(get_parameters_msg) - - # Request parameters - get_parameters_res = maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - # Serialize response - get_parameters_res_proto = serde.get_parameters_res_to_proto(get_parameters_res) - return ClientMessage(get_parameters_res=get_parameters_res_proto) - - -def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage: - # Deserialize fit instruction - fit_ins = serde.fit_ins_from_proto(fit_msg) - - # Perform fit - fit_res = maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - # Serialize fit result - fit_res_proto = serde.fit_res_to_proto(fit_res) - return ClientMessage(fit_res=fit_res_proto) - - -def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage: - # Deserialize evaluate instruction - evaluate_ins = serde.evaluate_ins_from_proto(evaluate_msg) - - # Perform evaluation - evaluate_res = maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - # Serialize evaluate result - evaluate_res_proto = serde.evaluate_res_to_proto(evaluate_res) - return ClientMessage(evaluate_res=evaluate_res_proto) +def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> bool: + """Validate the out message.""" + out_meta = out_message.metadata + in_meta = in_message_metadata + if ( # pylint: disable-next=too-many-boolean-expressions + out_meta.run_id == in_meta.run_id + and out_meta.message_id == "" # This will be generated by the server + and out_meta.src_node_id == in_meta.dst_node_id + and out_meta.dst_node_id == in_meta.src_node_id + and out_meta.reply_to_message == 
in_meta.message_id + and out_meta.group_id == in_meta.group_id + and out_meta.message_type == in_meta.message_type + ): + return True + return False diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index d7f410d81fc0..eaf16f7dc993 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -15,12 +15,16 @@ """Client-side message handler tests.""" +import unittest import uuid +from copy import copy +from typing import List from flwr.client import Client from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState from flwr.common import ( + Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -29,15 +33,17 @@ GetParametersRes, GetPropertiesIns, GetPropertiesRes, + Message, + Metadata, Parameters, - serde, - typing, + RecordSet, + Status, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, Code, ServerMessage, Status +from flwr.common import recordset_compat as compat +from flwr.common import typing +from flwr.common.constant import MessageTypeLegacy -from .message_handler import handle, handle_control_message +from .message_handler import handle_legacy_message_from_msgtype, validate_out_message class ClientWithoutProps(Client): @@ -116,137 +122,169 @@ def test_client_without_get_properties() -> None: """Test client implementing get_properties.""" # Prepare client = ClientWithoutProps() - ins = ServerMessage.GetPropertiesIns() - - task_ins: TaskIns = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - workload_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = 
Message( + metadata=Metadata( + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), - task_ins=task_ins, + message=message, + context=Context(state=RecordSet()), ) - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - workload_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) - ) - - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.GET_PROPERTIES_NOT_IMPLEMENTED, message="Client does not implement `get_properties`", - ) + ), + properties={}, + ) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + reply_to_message=message.metadata.message_id, + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + ), + content=expected_rs, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) - assert actual_msg == expected_msg - assert not disconnect_task_res - assert 
actual_sleep_duration == 0 + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata def test_client_with_get_properties() -> None: """Test client not implementing get_properties.""" # Prepare client = ClientWithProps() - ins = ServerMessage.GetPropertiesIns() - task_ins = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - workload_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), - task_ins=task_ins, - ) - - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - workload_id=0, - ) + message=message, + context=Context(state=RecordSet()), ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) - ) - - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( 
code=Code.OK, message="Success", ), - properties=serde.properties_to_proto( - properties={"str_prop": "val", "int_prop": 1} + properties={"str_prop": "val", "int_prop": 1}, + ) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + reply_to_message=message.metadata.message_id, + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, ), + content=expected_rs, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) - assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata + + +class TestMessageValidation(unittest.TestCase): + """Test message validation.""" + + def setUp(self) -> None: + """Set up the message validation.""" + # Common setup for tests + self.in_metadata = Metadata( + run_id=123, + message_id="qwerty", + src_node_id=10, + dst_node_id=20, + reply_to_message="", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.valid_out_metadata = Metadata( + run_id=123, + message_id="", + src_node_id=20, + dst_node_id=10, + reply_to_message="qwerty", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.common_content = RecordSet() + + def test_valid_message(self) -> None: + """Test a valid message.""" + # Prepare + valid_message = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Assert + self.assertTrue(validate_out_message(valid_message, self.in_metadata)) + + def test_invalid_message_run_id(self) -> None: + """Test invalid messages.""" + # Prepare + msg = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Execute + invalid_metadata_list: List[Metadata] = [] + attrs = list(vars(self.valid_out_metadata).keys()) + for attr in attrs: + if attr == 
"_partition_id": + continue + if attr == "_ttl": # Skip configurable ttl + continue + # Make an invalid metadata + invalid_metadata = copy(self.valid_out_metadata) + value = getattr(invalid_metadata, attr) + if isinstance(value, int): + value = 999 + elif isinstance(value, str): + value = "999" + setattr(invalid_metadata, attr, value) + # Add to list + invalid_metadata_list.append(invalid_metadata) + + # Assert + for invalid_metadata in invalid_metadata_list: + msg._metadata = invalid_metadata # pylint: disable=protected-access + self.assertFalse(validate_out_message(msg, self.in_metadata)) diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index fc24539998c0..7f515a30fe5a 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -17,21 +17,17 @@ from typing import Optional -from flwr.proto.fleet_pb2 import PullTaskInsResponse -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: +def validate_task_ins(task_ins: TaskIns) -> bool: """Validate a TaskIns before it entering the message handling process. Parameters ---------- task_ins: TaskIns The task instruction coming from the server. - discard_reconnect_ins: bool - If True, ReconnectIns will not be considered as valid content. Returns ------- @@ -39,58 +35,8 @@ def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: True if the TaskIns is deemed valid and therefore suitable for undergoing the message handling process, False otherwise. """ - # Check if the task_ins contains legacy_server_message or sa. 
- # If legacy_server_message is set, check if ServerMessage is one of - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns, ReconnectIns*} - # Discard ReconnectIns if discard_reconnect_ins is true. - if ( - not task_ins.HasField("task") - or ( - not task_ins.task.HasField("legacy_server_message") - and not task_ins.task.HasField("sa") - ) - or ( - discard_reconnect_ins - and task_ins.task.legacy_server_message.WhichOneof("msg") == "reconnect_ins" - ) - ): + if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")): return False - - return True - - -def validate_task_res(task_res: TaskRes) -> bool: - """Validate a TaskRes before filling its fields in the `send()` function. - - Parameters - ---------- - task_res: TaskRes - The task response to be sent to the server. - - Returns - ------- - is_valid: bool - True if the `task_id`, `group_id`, and `workload_id` fields in TaskRes - and the `producer`, `consumer`, and `ancestry` fields in its sub-message Task - are not initialized accidentally elsewhere, - False otherwise. 
- """ - # Retrieve initialized fields in TaskRes and Task - initialized_fields_in_task_res = {field.name for field, _ in task_res.ListFields()} - initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} - - # Check if certain fields are already initialized - # pylint: disable-next=too-many-boolean-expressions - if ( - "task_id" in initialized_fields_in_task_res - or "group_id" in initialized_fields_in_task_res - or "workload_id" in initialized_fields_in_task_res - or "producer" in initialized_fields_in_task - or "consumer" in initialized_fields_in_task - or "ancestry" in initialized_fields_in_task - ): - return False - return True @@ -106,61 +52,3 @@ def get_task_ins( task_ins: TaskIns = pull_task_ins_response.task_ins_list[0] return task_ins - - -def get_server_message_from_task_ins( - task_ins: TaskIns, exclude_reconnect_ins: bool -) -> Optional[ServerMessage]: - """Get ServerMessage from TaskIns, if available.""" - # Return the message if it is in - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns} - # Return the message if it is ReconnectIns and exclude_reconnect_ins is False. - if not validate_task_ins( - task_ins, discard_reconnect_ins=exclude_reconnect_ins - ) or not task_ins.task.HasField("legacy_server_message"): - return None - - return task_ins.task.legacy_server_message - - -def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: - """Wrap ClientMessage in TaskRes.""" - # Instantiate a TaskRes, only filling client_message field. - return TaskRes( - task_id="", - group_id="", - workload_id=0, - task=Task(ancestry=[], legacy_client_message=client_message), - ) - - -def configure_task_res( - task_res: TaskRes, ref_task_ins: TaskIns, producer: Node -) -> TaskRes: - """Set the metadata of a TaskRes. - - Fill `group_id` and `workload_id` in TaskRes - and `producer`, `consumer`, and `ancestry` in Task in TaskRes. - - `producer` in Task in TaskRes will remain unchanged/unset. 
- - Note that protobuf API `protobuf.message.MergeFrom(other_msg)` - does NOT always overwrite fields that are set in `other_msg`. - Please refer to: - https://googleapis.dev/python/protobuf/latest/google/protobuf/message.html - """ - task_res = TaskRes( - task_id="", # This will be generated by the server - group_id=ref_task_ins.group_id, - workload_id=ref_task_ins.workload_id, - task=task_res.task, - ) - # pylint: disable-next=no-member - task_res.task.MergeFrom( - Task( - producer=producer, - consumer=ref_task_ins.task.producer, - ancestry=[ref_task_ins.task_id], - ) - ) - return task_res diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 21f3a2ead98a..c8b9e14737ff 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -15,100 +15,31 @@ """Tests for module task_handler.""" -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - get_task_ins, - validate_task_ins, - validate_task_res, - wrap_client_message_in_task_res, -) -from flwr.proto.fleet_pb2 import PullTaskInsResponse -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins +from flwr.common import RecordSet, serde +from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 def test_validate_task_ins_no_task() -> None: """Test validate_task_ins.""" task_ins = TaskIns(task=None) - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) def test_validate_task_ins_no_content() -> None: """Test validate_task_ins.""" - task_ins = 
TaskIns(task=Task(legacy_server_message=None, sa=None)) + task_ins = TaskIns(task=Task(recordset=None)) - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) -def test_validate_task_ins_with_reconnect_ins() -> None: +def test_validate_task_ins_valid() -> None: """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - ) - ) - - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_valid_legacy_server_message() -> None: - """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - ) - ) - - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_valid_sa() -> None: - """Test validate_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) - - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_res() -> None: - """Test validate_task_res.""" - task_res = TaskRes(task=Task()) - assert validate_task_res(task_res) - - task_res.task_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.group_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.workload_id = 61016 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.producer.node_id = 0 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.consumer.node_id = 0 - assert not 
validate_task_res(task_res) + task_ins = TaskIns(task=Task(recordset=serde.recordset_to_proto(RecordSet()))) - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.ancestry.append("123") - assert not validate_task_res(task_res) + assert validate_task_ins(task_ins) def test_get_task_ins_empty_response() -> None: @@ -134,61 +65,3 @@ def test_get_task_ins_multiple_ins() -> None: ) actual_task_ins = get_task_ins(res) assert actual_task_ins == expected_task_ins - - -def test_get_server_message_from_task_ins_invalid() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(legacy_server_message=None)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_reconnect_ins() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f == expected_server_message - - -def test_get_server_message_from_task_ins_sa() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_valid_legacy_server_message() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - 
task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t == expected_server_message - assert msg_f == expected_server_message - - -def test_wrap_client_message_in_task_res() -> None: - """Test wrap_client_message_in_task_res.""" - expected_client_message = ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes() - ) - task_res = wrap_client_message_in_task_res(expected_client_message) - - assert validate_task_res(task_res) - # pylint: disable-next=no-member - assert task_res.task.legacy_client_message == expected_client_message diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py new file mode 100644 index 000000000000..1cd79fa944fe --- /dev/null +++ b/src/py/flwr/client/mod/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Mods.""" + + +from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod +from .comms_mods import message_size_mod, parameters_size_mod +from .localdp_mod import LocalDpMod +from .secure_aggregation import secagg_mod, secaggplus_mod +from .utils import make_ffn + +__all__ = [ + "adaptiveclipping_mod", + "fixedclipping_mod", + "LocalDpMod", + "make_ffn", + "secagg_mod", + "secaggplus_mod", + "message_size_mod", + "parameters_size_mod", +] diff --git a/src/py/flwr/client/mod/centraldp_mods.py b/src/py/flwr/client/mod/centraldp_mods.py new file mode 100644 index 000000000000..4f4a595e8d9c --- /dev/null +++ b/src/py/flwr/client/mod/centraldp_mods.py @@ -0,0 +1,157 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Clipping modifiers for central DP with client-side clipping.""" + + +from logging import INFO + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.context import Context +from flwr.common.differential_privacy import ( + compute_adaptive_clip_model_update, + compute_clip_model_update, +) +from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT +from flwr.common.logger import log +from flwr.common.message import Message + + +def fixedclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side fixed clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It operates on messages of type `MessageType.TRAIN`. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, fixedclipping_mod should be the last to operate on params. + """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." 
+ ) + + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + log(INFO, "fixedclipping_mod: parameters are clipped by value: %s.", clipping_norm) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg + + +def adaptiveclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side adaptive clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It also sends KEY_NORM_BIT to the server for computing the new clipping value. + + It operates on messages of type `MessageType.TRAIN`. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, adaptiveclipping_mod should be the last to operate on params. + """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." 
+ ) + if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float): + raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.") + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + norm_bit = compute_adaptive_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + log( + INFO, + "adaptiveclipping_mod: parameters are clipped by value: %s.", + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + + fit_res.metrics[KEY_NORM_BIT] = norm_bit + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/client/mod/comms_mods.py b/src/py/flwr/client/mod/comms_mods.py new file mode 100644 index 000000000000..102d2f477262 --- /dev/null +++ b/src/py/flwr/client/mod/comms_mods.py @@ -0,0 +1,79 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Mods that report statistics about message communication.""" + +from logging import INFO + +import numpy as np + +from flwr.client.typing import ClientAppCallable +from flwr.common.context import Context +from flwr.common.logger import log +from flwr.common.message import Message + + +def message_size_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Message size mod. + + This mod logs the size in Bytes of the message being transmited. + """ + message_size_in_bytes = 0 + + for p_record in msg.content.parameters_records.values(): + message_size_in_bytes += p_record.count_bytes() + + for c_record in msg.content.configs_records.values(): + message_size_in_bytes += c_record.count_bytes() + + for m_record in msg.content.metrics_records.values(): + message_size_in_bytes += m_record.count_bytes() + + log(INFO, "Message size: %i Bytes", message_size_in_bytes) + + return call_next(msg, ctxt) + + +def parameters_size_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Parameters size mod. + + This mod logs the number of parameters transmitted in the message as well as their + size in Bytes. 
+ """ + model_size_stats = {} + parameters_size_in_bytes = 0 + for record_name, p_record in msg.content.parameters_records.items(): + p_record_bytes = p_record.count_bytes() + parameters_size_in_bytes += p_record_bytes + parameter_count = 0 + for array in p_record.values(): + parameter_count += ( + int(np.prod(array.shape)) if array.shape else array.numpy().size + ) + + model_size_stats[f"{record_name}"] = { + "parameters": parameter_count, + "bytes": p_record_bytes, + } + + if model_size_stats: + log(INFO, model_size_stats) + + log(INFO, "Total parameters transmited: %i Bytes", parameters_size_in_bytes) + + return call_next(msg, ctxt) diff --git a/src/py/flwr/client/mod/localdp_mod.py b/src/py/flwr/client/mod/localdp_mod.py new file mode 100644 index 000000000000..3b0311a612b9 --- /dev/null +++ b/src/py/flwr/client/mod/localdp_mod.py @@ -0,0 +1,148 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Local DP modifier.""" + + +from logging import INFO + +import numpy as np + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.context import Context +from flwr.common.differential_privacy import ( + add_localdp_gaussian_noise_to_params, + compute_clip_model_update, +) +from flwr.common.logger import log +from flwr.common.message import Message + + +class LocalDpMod: + """Modifier for local differential privacy. + + This mod clips the client model updates and + adds noise to the params before sending them to the server. + + It operates on messages of type `MessageType.TRAIN`. + + Parameters + ---------- + clipping_norm : float + The value of the clipping norm. + sensitivity : float + The sensitivity of the client model. + epsilon : float + The privacy budget. + Smaller value of epsilon indicates a higher level of privacy protection. + delta : float + The failure probability. + The probability that the privacy mechanism + fails to provide the desired level of privacy. + A smaller value of delta indicates a stricter privacy guarantee. + + Examples + -------- + Create an instance of the local DP mod and add it to the client-side mods: + + >>> local_dp_mod = LocalDpMod( ... 
) + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[local_dp_mod] + >>> ) + """ + + def __init__( + self, clipping_norm: float, sensitivity: float, epsilon: float, delta: float + ) -> None: + if clipping_norm <= 0: + raise ValueError("The clipping norm should be a positive value.") + + if sensitivity < 0: + raise ValueError("The sensitivity should be a non-negative value.") + + if epsilon < 0: + raise ValueError("Epsilon should be a non-negative value.") + + if delta < 0: + raise ValueError("Delta should be a non-negative value.") + + self.clipping_norm = clipping_norm + self.sensitivity = sensitivity + self.epsilon = epsilon + self.delta = delta + + def __call__( + self, msg: Message, ctxt: Context, call_next: ClientAppCallable + ) -> Message: + """Perform local DP on the client model parameters. + + Parameters + ---------- + msg : Message + The message received from the server. + ctxt : Context + The context of the client. + call_next : ClientAppCallable + The callable to call the next middleware in the chain. + + Returns + ------- + Message + The modified message to be sent back to the server. 
+ """ + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + + # Check if the msg has error + if out_msg.has_error(): + return out_msg + + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + self.clipping_norm, + ) + log( + INFO, "LocalDpMod: parameters are clipped by value: %s.", self.clipping_norm + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + + # Add noise to model params + add_localdp_gaussian_noise_to_params( + fit_res.parameters, self.sensitivity, self.epsilon, self.delta + ) + log( + INFO, + "LocalDpMod: local DP noise with " + "standard deviation: %s added to parameters.", + self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon, + ) + + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/client/workload_state.py b/src/py/flwr/client/mod/secure_aggregation/__init__.py similarity index 77% rename from src/py/flwr/client/workload_state.py rename to src/py/flwr/client/mod/secure_aggregation/__init__.py index 42ae2a925f47..8892d8c03935 100644 --- a/src/py/flwr/client/workload_state.py +++ b/src/py/flwr/client/mod/secure_aggregation/__init__.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Workload state.""" +"""Secure Aggregation mods.""" -from dataclasses import dataclass -from typing import Dict +from .secagg_mod import secagg_mod +from .secaggplus_mod import secaggplus_mod -@dataclass -class WorkloadState: - """State of a workload executed by a client node.""" - - state: Dict[str, str] +__all__ = [ + "secagg_mod", + "secaggplus_mod", +] diff --git a/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py new file mode 100644 index 000000000000..d87af59a4e6e --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Modifier for the SecAgg protocol.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import Context, Message + +from .secaggplus_mod import secaggplus_mod + + +def secagg_mod( + msg: Message, + ctxt: Context, + call_next: ClientAppCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg protocol.""" + return secaggplus_mod(msg, ctxt, call_next) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py new file mode 100644 index 000000000000..989d5f6e1361 --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -0,0 +1,518 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Modifier for the SecAgg+ protocol.""" + + +import os +from dataclasses import dataclass, field +from logging import DEBUG, WARNING +from typing import Any, Dict, List, Tuple, cast + +from flwr.client.typing import ClientAppCallable +from flwr.common import ( + ConfigsRecord, + Context, + Message, + Parameters, + RecordSet, + ndarray_to_bytes, + parameters_to_ndarrays, +) +from flwr.common import recordset_compat as compat +from flwr.common.constant import MessageType +from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.shamir import create_shares +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_private_key, + bytes_to_public_key, + decrypt, + encrypt, + generate_key_pairs, + generate_shared_key, + private_key_to_bytes, + public_key_to_bytes, +) +from flwr.common.secure_aggregation.ndarrays_arithmetic import ( + factor_combine, + parameters_addition, + parameters_mod, + parameters_multiply, + parameters_subtraction, +) +from flwr.common.secure_aggregation.quantization import quantize +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, + Key, + Stage, +) +from flwr.common.secure_aggregation.secaggplus_utils import ( + pseudo_rand_gen, + share_keys_plaintext_concat, + share_keys_plaintext_separate, +) +from flwr.common.typing import ConfigsRecordValues + + +@dataclass +# pylint: disable-next=too-many-instance-attributes +class SecAggPlusState: + """State of the SecAgg+ protocol.""" + + current_stage: str = Stage.UNMASK + + nid: int = 0 + sample_num: int = 0 + share_num: int = 0 + threshold: int = 0 + clipping_range: float = 0.0 + target_range: int = 0 + mod_range: int = 0 + max_weight: float = 0.0 + + # Secret key (sk) and public key (pk) + sk1: bytes = b"" + pk1: bytes = b"" + sk2: bytes = b"" + pk2: bytes = b"" + + # Random seed for generating the private mask + 
rd_seed: bytes = b"" + + rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) + sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) + # The dict of the shared secrets from sk2 + ss2_dict: Dict[int, bytes] = field(default_factory=dict) + public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) + + def __init__(self, **kwargs: ConfigsRecordValues) -> None: + for k, v in kwargs.items(): + if k.endswith(":V"): + continue + new_v: Any = v + if k.endswith(":K"): + k = k[:-2] + keys = cast(List[int], v) + values = cast(List[bytes], kwargs[f"{k}:V"]) + if len(values) > len(keys): + updated_values = [ + tuple(values[i : i + 2]) for i in range(0, len(values), 2) + ] + new_v = dict(zip(keys, updated_values)) + else: + new_v = dict(zip(keys, values)) + self.__setattr__(k, new_v) + + def to_dict(self) -> Dict[str, ConfigsRecordValues]: + """Convert the state to a dictionary.""" + ret = vars(self) + for k in list(ret.keys()): + if isinstance(ret[k], dict): + # Replace dict with two lists + v = cast(Dict[str, Any], ret.pop(k)) + ret[f"{k}:K"] = list(v.keys()) + if k == "public_keys_dict": + v_list: List[bytes] = [] + for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list.extend(b1_b2) + ret[f"{k}:V"] = v_list + else: + ret[f"{k}:V"] = list(v.values()) + return ret + + +def secaggplus_mod( + msg: Message, + ctxt: Context, + call_next: ClientAppCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg+ protocol.""" + # Ignore non-fit messages + if msg.metadata.message_type != MessageType.TRAIN: + return call_next(msg, ctxt) + + # Retrieve local state + if RECORD_KEY_STATE not in ctxt.state.configs_records: + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({}) + state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] + state = SecAggPlusState(**state_dict) + + # Retrieve incoming configs + configs = msg.content.configs_records[RECORD_KEY_CONFIGS] + + # Check the validity of 
the next stage + check_stage(state.current_stage, configs) + + # Update the current stage + state.current_stage = cast(str, configs.pop(Key.STAGE)) + + # Check the validity of the configs based on the current stage + check_configs(state.current_stage, configs) + + # Execute + out_content = RecordSet() + if state.current_stage == Stage.SETUP: + state.nid = msg.metadata.dst_node_id + res = _setup(state, configs) + elif state.current_stage == Stage.SHARE_KEYS: + res = _share_keys(state, configs) + elif state.current_stage == Stage.COLLECT_MASKED_VECTORS: + out_msg = call_next(msg, ctxt) + out_content = out_msg.content + fitres = compat.recordset_to_fitres(out_content, keep_input=True) + res = _collect_masked_vectors( + state, configs, fitres.num_examples, fitres.parameters + ) + for p_record in out_content.parameters_records.values(): + p_record.clear() + elif state.current_stage == Stage.UNMASK: + res = _unmask(state, configs) + else: + raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}") + + # Save state + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) + + # Return message + out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) + return msg.create_reply(out_content, ttl="") + + +def check_stage(current_stage: str, configs: ConfigsRecord) -> None: + """Check the validity of the next stage.""" + # Check the existence of Config.STAGE + if Key.STAGE not in configs: + raise KeyError( + f"The required key '{Key.STAGE}' is missing from the ConfigsRecord." + ) + + # Check the value type of the Config.STAGE + next_stage = configs[Key.STAGE] + if not isinstance(next_stage, str): + raise TypeError( + f"The value for the key '{Key.STAGE}' must be of type {str}, " + f"but got {type(next_stage)} instead." 
+ ) + + # Check the validity of the next stage + if next_stage == Stage.SETUP: + if current_stage != Stage.UNMASK: + log(WARNING, "Restart from the setup stage") + # If stage is not "setup", + # the stage from configs should be the expected next stage + else: + stages = Stage.all() + expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)] + if next_stage != expected_next_stage: + raise ValueError( + "Abort secure aggregation: " + f"expect {expected_next_stage} stage, but receive {next_stage} stage" + ) + + +# pylint: disable-next=too-many-branches +def check_configs(stage: str, configs: ConfigsRecord) -> None: + """Check the validity of the configs.""" + # Check configs for the setup stage + if stage == Stage.SETUP: + key_type_pairs = [ + (Key.SAMPLE_NUMBER, int), + (Key.SHARE_NUMBER, int), + (Key.THRESHOLD, int), + (Key.CLIPPING_RANGE, float), + (Key.TARGET_RANGE, int), + (Key.MOD_RANGE, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.SETUP}: the required key '{key}' is " + "missing from the ConfigsRecord." + ) + # Bool is a subclass of int in Python, + # so `isinstance(v, int)` will return True even if v is a boolean. + # pylint: disable-next=unidiomatic-typecheck + if type(configs[key]) is not expected_type: + raise TypeError( + f"Stage {Stage.SETUP}: The value for the key '{key}' " + f"must be of type {expected_type}, " + f"but got {type(configs[key])} instead." + ) + elif stage == Stage.SHARE_KEYS: + for key, value in configs.items(): + if ( + not isinstance(value, list) + or len(value) != 2 + or not isinstance(value[0], bytes) + or not isinstance(value[1], bytes) + ): + raise TypeError( + f"Stage {Stage.SHARE_KEYS}: " + f"the value for the key '{key}' must be a list of two bytes." 
+ ) + elif stage == Stage.COLLECT_MASKED_VECTORS: + key_type_pairs = [ + (Key.CIPHERTEXT_LIST, bytes), + (Key.SOURCE_LIST, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " + f"the required key '{key}' is " + "missing from the ConfigsRecord." + ) + if not isinstance(configs[key], list) or any( + elm + for elm in cast(List[Any], configs[key]) + # pylint: disable-next=unidiomatic-typecheck + if type(elm) is not expected_type + ): + raise TypeError( + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " + f"the value for the key '{key}' " + f"must be of type List[{expected_type.__name__}]" + ) + elif stage == Stage.UNMASK: + key_type_pairs = [ + (Key.ACTIVE_NODE_ID_LIST, int), + (Key.DEAD_NODE_ID_LIST, int), + ] + for key, expected_type in key_type_pairs: + if key not in configs: + raise KeyError( + f"Stage {Stage.UNMASK}: " + f"the required key '{key}' is " + "missing from the ConfigsRecord." + ) + if not isinstance(configs[key], list) or any( + elm + for elm in cast(List[Any], configs[key]) + # pylint: disable-next=unidiomatic-typecheck + if type(elm) is not expected_type + ): + raise TypeError( + f"Stage {Stage.UNMASK}: " + f"the value for the key '{key}' " + f"must be of type List[{expected_type.__name__}]" + ) + else: + raise ValueError(f"Unknown secagg stage: {stage}") + + +def _setup( + state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: + # Assigning parameter values to object fields + sec_agg_param_dict = configs + state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) + log(DEBUG, "Node %d: starting stage 0...", state.nid) + + state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER]) + state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD]) + state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE]) + state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE]) + state.mod_range = cast(int, 
sec_agg_param_dict[Key.MOD_RANGE]) + state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT]) + + # Dictionaries containing node IDs as keys + # and their respective secret shares as values. + state.rd_seed_share_dict = {} + state.sk1_share_dict = {} + # Dictionary containing node IDs as keys + # and their respective shared secrets (with this client) as values. + state.ss2_dict = {} + + # Create 2 sets private public key pairs + # One for creating pairwise masks + # One for encrypting message to distribute shares + sk1, pk1 = generate_key_pairs() + sk2, pk2 = generate_key_pairs() + + state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) + state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) + log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid) + return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2} + + +# pylint: disable-next=too-many-locals +def _share_keys( + state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: + named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) + key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} + log(DEBUG, "Node %d: starting stage 1...", state.nid) + state.public_keys_dict = key_dict + + # Check if the size is larger than threshold + if len(state.public_keys_dict) < state.threshold: + raise ValueError("Available neighbours number smaller than threshold") + + # Check if all public keys are unique + pk_list: List[bytes] = [] + for pk1, pk2 in state.public_keys_dict.values(): + pk_list.append(pk1) + pk_list.append(pk2) + if len(set(pk_list)) != len(pk_list): + raise ValueError("Some public keys are identical") + + # Check if public keys of this client are correct in the dictionary + if ( + state.public_keys_dict[state.nid][0] != state.pk1 + or state.public_keys_dict[state.nid][1] != state.pk2 + ): + raise ValueError( + "Own public keys are displayed in dict incorrectly, 
should not happen!" + ) + + # Generate the private mask seed + state.rd_seed = os.urandom(32) + + # Create shares for the private mask seed and the first private key + b_shares = create_shares(state.rd_seed, state.threshold, state.share_num) + sk1_shares = create_shares(state.sk1, state.threshold, state.share_num) + + srcs, dsts, ciphertexts = [], [], [] + + # Distribute shares + for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()): + if nid == state.nid: + state.rd_seed_share_dict[state.nid] = b_shares[idx] + state.sk1_share_dict[state.nid] = sk1_shares[idx] + else: + shared_key = generate_shared_key( + bytes_to_private_key(state.sk2), + bytes_to_public_key(pk2), + ) + state.ss2_dict[nid] = shared_key + plaintext = share_keys_plaintext_concat( + state.nid, nid, b_shares[idx], sk1_shares[idx] + ) + ciphertext = encrypt(shared_key, plaintext) + srcs.append(state.nid) + dsts.append(nid) + ciphertexts.append(ciphertext) + + log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid) + return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts} + + +# pylint: disable-next=too-many-locals +def _collect_masked_vectors( + state: SecAggPlusState, + configs: ConfigsRecord, + num_examples: int, + updated_parameters: Parameters, +) -> Dict[str, ConfigsRecordValues]: + log(DEBUG, "Node %d: starting stage 2...", state.nid) + available_clients: List[int] = [] + ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST]) + srcs = cast(List[int], configs[Key.SOURCE_LIST]) + if len(ciphertexts) + 1 < state.threshold: + raise ValueError("Not enough available neighbour clients.") + + # Decrypt ciphertexts, verify their sources, and store shares. 
+ for src, ciphertext in zip(srcs, ciphertexts): + shared_key = state.ss2_dict[src] + plaintext = decrypt(shared_key, ciphertext) + actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate( + plaintext + ) + available_clients.append(src) + if src != actual_src: + raise ValueError( + f"Node {state.nid}: received ciphertext " + f"from {actual_src} instead of {src}." + ) + if dst != state.nid: + raise ValueError( + f"Node {state.nid}: received an encrypted message" + f"for Node {dst} from Node {src}." + ) + state.rd_seed_share_dict[src] = rd_seed_share + state.sk1_share_dict[src] = sk1_share + + # Fit + ratio = num_examples / state.max_weight + if ratio > 1: + log( + WARNING, + "Potential overflow warning: the provided weight (%s) exceeds the specified" + " max_weight (%s). This may lead to overflow issues.", + num_examples, + state.max_weight, + ) + q_ratio = round(ratio * state.target_range) + dq_ratio = q_ratio / state.target_range + + parameters = parameters_to_ndarrays(updated_parameters) + parameters = parameters_multiply(parameters, dq_ratio) + + # Quantize parameter update (vector) + quantized_parameters = quantize( + parameters, state.clipping_range, state.target_range + ) + + quantized_parameters = factor_combine(q_ratio, quantized_parameters) + + dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] + + # Add private mask + private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) + quantized_parameters = parameters_addition(quantized_parameters, private_mask) + + for node_id in available_clients: + # Add pairwise masks + shared_key = generate_shared_key( + bytes_to_private_key(state.sk1), + bytes_to_public_key(state.public_keys_dict[node_id][0]), + ) + pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list) + if state.nid > node_id: + quantized_parameters = parameters_addition( + quantized_parameters, pairwise_mask + ) + else: + quantized_parameters = 
 parameters_subtraction( +                quantized_parameters, pairwise_mask +            ) + +    # Take mod of final weight update vector and return to server +    quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) +    log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid) +    return { +        Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] +    } + + +def _unmask( +    state: SecAggPlusState, configs: ConfigsRecord +) -> Dict[str, ConfigsRecordValues]: +    log(DEBUG, "Node %d: starting stage 3...", state.nid) + +    active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST]) +    dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST]) +    # Send private mask seed share for every available client (including itself) +    # Send first private key share for building pairwise mask for every dropped client +    if len(active_nids) < state.threshold: +        raise ValueError("Available neighbours number smaller than threshold") + +    all_nids, shares = [], [] +    all_nids = active_nids + dead_nids +    shares += [state.rd_seed_share_dict[nid] for nid in active_nids] +    shares += [state.sk1_share_dict[nid] for nid in dead_nids] + +    log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid) +    return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares} diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py new file mode 100644 index 000000000000..db5ed67c02a4 --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -0,0 +1,316 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The SecAgg+ protocol modifier tests.""" + +import unittest +from itertools import product +from typing import Callable, Dict, List + +from flwr.client.mod import make_ffn +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet +from flwr.common.constant import MessageType +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, + Key, + Stage, +) +from flwr.common.typing import ConfigsRecordValues + +from .secaggplus_mod import SecAggPlusState, check_configs, secaggplus_mod + + +def get_test_handler( + ctxt: Context, +) -> Callable[[Dict[str, ConfigsRecordValues]], ConfigsRecord]: + """.""" + + def empty_ffn(_msg: Message, _2: Context) -> Message: + return _msg.create_reply(RecordSet(), ttl="") + + app = make_ffn(empty_ffn, [secaggplus_mod]) + + def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: + in_msg = Message( + metadata=Metadata( + run_id=0, + message_id="", + src_node_id=0, + dst_node_id=123, + reply_to_message="", + group_id="", + ttl="", + message_type=MessageType.TRAIN, + ), + content=RecordSet( + configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)} + ), + ) + out_msg = app(in_msg, ctxt) + return out_msg.content.configs_records[RECORD_KEY_CONFIGS] + + return func + + +def _make_ctxt() -> Context: + cfg = ConfigsRecord(SecAggPlusState().to_dict()) + return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg})) + + +def _make_set_state_fn( + ctxt: Context, +) -> 
Callable[[str], None]: + def set_stage(stage: str) -> None: + state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] + state = SecAggPlusState(**state_dict) + state.current_stage = stage + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) + + return set_stage + + +class TestSecAggPlusHandler(unittest.TestCase): + """Test the SecAgg+ protocol handler.""" + + def test_stage_transition(self) -> None: + """Test stage transition.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + assert Stage.all() == ( + Stage.SETUP, + Stage.SHARE_KEYS, + Stage.COLLECT_MASKED_VECTORS, + Stage.UNMASK, + ) + + valid_transitions = { + # From one stage to the next stage + (Stage.UNMASK, Stage.SETUP), + (Stage.SETUP, Stage.SHARE_KEYS), + (Stage.SHARE_KEYS, Stage.COLLECT_MASKED_VECTORS), + (Stage.COLLECT_MASKED_VECTORS, Stage.UNMASK), + # From any stage to the initial stage + # Such transitions will log a warning. + (Stage.SETUP, Stage.SETUP), + (Stage.SHARE_KEYS, Stage.SETUP), + (Stage.COLLECT_MASKED_VECTORS, Stage.SETUP), + } + + invalid_transitions = set(product(Stage.all(), Stage.all())).difference( + valid_transitions + ) + + # Test valid transitions + # If the next stage is valid, the function should update the current stage + # and then raise KeyError or other exceptions when trying to execute SA. 
+ for current_stage, next_stage in valid_transitions: + set_stage(current_stage) + + with self.assertRaises(KeyError): + handler({Key.STAGE: next_stage}) + + # Test invalid transitions + # If the next stage is invalid, the function should raise ValueError + for current_stage, next_stage in invalid_transitions: + set_stage(current_stage) + + with self.assertRaises(ValueError): + handler({Key.STAGE: next_stage}) + + def test_stage_setup_check(self) -> None: + """Test content checking for the setup stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_key_type_pairs = [ + (Key.SAMPLE_NUMBER, int), + (Key.SHARE_NUMBER, int), + (Key.THRESHOLD, int), + (Key.CLIPPING_RANGE, float), + (Key.TARGET_RANGE, int), + (Key.MOD_RANGE, int), + ] + + type_to_test_value: Dict[type, ConfigsRecordValues] = { + int: 10, + bool: True, + float: 1.0, + str: "test", + bytes: b"test", + } + + valid_configs: Dict[str, ConfigsRecordValues] = { + key: type_to_test_value[value_type] + for key, value_type in valid_key_type_pairs + } + + # Test valid configs + try: + check_configs(Stage.SETUP, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.SETUP + + # Test invalid configs + for key, value_type in valid_key_type_pairs: + invalid_configs = valid_configs.copy() + + # Test wrong value type for the key + for other_type, other_value in type_to_test_value.items(): + if other_type == value_type: + continue + invalid_configs[key] = other_value + + set_stage(Stage.UNMASK) + with self.assertRaises(TypeError): + handler(invalid_configs.copy()) + + # Test missing key + invalid_configs.pop(key) + + set_stage(Stage.UNMASK) + with self.assertRaises(KeyError): + handler(invalid_configs.copy()) + + def test_stage_share_keys_check(self) -> None: + """Test content checking for the share keys 
stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + "1": [b"public key 1", b"public key 2"], + "2": [b"public key 1", b"public key 2"], + "3": [b"public key 1", b"public key 2"], + } + + # Test valid configs + try: + check_configs(Stage.SHARE_KEYS, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.SHARE_KEYS + + # Test invalid configs + invalid_values: List[ConfigsRecordValues] = [ + b"public key 1", + [b"public key 1"], + [b"public key 1", b"public key 2", b"public key 3"], + ] + + for value in invalid_values: + invalid_configs = valid_configs.copy() + invalid_configs["1"] = value + + set_stage(Stage.SETUP) + with self.assertRaises(TypeError): + handler(invalid_configs.copy()) + + def test_stage_collect_masked_vectors_check(self) -> None: + """Test content checking for the collect masked vectors stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + Key.CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], + Key.SOURCE_LIST: [32, 51324, 32324123, -3], + } + + # Test valid configs + try: + check_configs(Stage.COLLECT_MASKED_VECTORS, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.COLLECT_MASKED_VECTORS + + # Test invalid configs + # Test missing keys + for key in list(valid_configs.keys()): + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(Stage.SHARE_KEYS) + with self.assertRaises(KeyError): + handler(invalid_configs) + + # Test wrong value type for the 
key + for key in valid_configs: + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs[key] = [3.1415926] + + set_stage(Stage.SHARE_KEYS) + with self.assertRaises(TypeError): + handler(invalid_configs) + + def test_stage_unmask_check(self) -> None: + """Test content checking for the unmasking stage.""" + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) + + valid_configs: Dict[str, ConfigsRecordValues] = { + Key.ACTIVE_NODE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + Key.DEAD_NODE_ID_LIST: [32, 51324, 32324123, -3], + } + + # Test valid configs + try: + check_configs(Stage.UNMASK, ConfigsRecord(valid_configs)) + # pylint: disable-next=broad-except + except Exception as exc: + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") + + # Set the stage + valid_configs[Key.STAGE] = Stage.UNMASK + + # Test invalid configs + # Test missing keys + for key in list(valid_configs.keys()): + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(Stage.COLLECT_MASKED_VECTORS) + with self.assertRaises(KeyError): + handler(invalid_configs) + + # Test wrong value type for the key + for key in valid_configs: + if key == Key.STAGE: + continue + invalid_configs = valid_configs.copy() + invalid_configs[key] = [True, False, True, False] + + set_stage(Stage.COLLECT_MASKED_VECTORS) + with self.assertRaises(TypeError): + handler(invalid_configs) diff --git a/src/py/flwr/client/mod/utils.py b/src/py/flwr/client/mod/utils.py new file mode 100644 index 000000000000..4c3c32944f01 --- /dev/null +++ b/src/py/flwr/client/mod/utils.py @@ -0,0 +1,36 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for mods.""" + + +from typing import List + +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common import Context, Message + + +def make_ffn(ffn: ClientAppCallable, mods: List[Mod]) -> ClientAppCallable: + """.""" + + def wrap_ffn(_ffn: ClientAppCallable, _mod: Mod) -> ClientAppCallable: + def new_ffn(message: Message, context: Context) -> Message: + return _mod(message, context, _ffn) + + return new_ffn + + for mod in reversed(mods): + ffn = wrap_ffn(ffn, mod) + + return ffn diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py new file mode 100644 index 000000000000..e588b8b53b3b --- /dev/null +++ b/src/py/flwr/client/mod/utils_test.py @@ -0,0 +1,154 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the utility functions.""" + + +import unittest +from typing import List, cast + +from flwr.client.typing import ClientAppCallable, Mod +from flwr.common import ( +    ConfigsRecord, +    Context, +    Message, +    Metadata, +    MetricsRecord, +    RecordSet, +) + +from .utils import make_ffn + +METRIC = "context" +COUNTER = "counter" + + +def _increment_context_counter(context: Context) -> None: +    # Read from context +    current_counter = cast(int, context.state.metrics_records[METRIC][COUNTER]) +    # update and override context +    current_counter += 1 +    context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter}) + + +def make_mock_mod(name: str, footprint: List[str]) -> Mod: +    """Make a mock mod.""" + +    def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: +        footprint.append(name) +        # add empty ConfigsRecord to in_message for this mod +        message.content.configs_records[name] = ConfigsRecord() +        _increment_context_counter(context) +        out_message: Message = app(message, context) +        footprint.append(name) +        _increment_context_counter(context) +        # add empty ConfigsRecord to out_message for this mod +        out_message.content.configs_records[name] = ConfigsRecord() +        return out_message + +    return mod + + +def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable: +    """Make a mock app.""" + +    def app(message: Message, context: Context) -> Message: +        footprint.append(name) +        message.content.configs_records[name] = ConfigsRecord() +        out_message = Message(metadata=message.metadata, content=RecordSet()) +        out_message.content.configs_records[name] = ConfigsRecord() +        print(context) +        return out_message + +    return app + + +def _get_dummy_flower_message() -> Message: +    return Message( +        content=RecordSet(), +        metadata=Metadata( +            run_id=0, +            message_id="", +            group_id="", +            src_node_id=0, +            dst_node_id=0, +            reply_to_message="", +            ttl="", +
message_type="mock", + ), + ) + + +class TestMakeApp(unittest.TestCase): + """Tests for the `make_app` function.""" + + def test_multiple_mods(self) -> None: + """Test if multiple mods are called in the correct order.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + mock_mod_names = [f"mod{i}" for i in range(1, 15)] + mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names] + + state = RecordSet() + state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) + context = Context(state=state) + message = _get_dummy_flower_message() + + # Execute + wrapped_app = make_ffn(mock_app, mock_mods) + out_message = wrapped_app(message, context) + + # Assert + trace = mock_mod_names + ["app"] + self.assertEqual(footprint, trace + list(reversed(mock_mod_names))) + # pylint: disable-next=no-member + self.assertEqual( + "".join(message.content.configs_records.keys()), "".join(trace) + ) + self.assertEqual( + "".join(out_message.content.configs_records.keys()), + "".join(reversed(trace)), + ) + self.assertEqual(state.metrics_records[METRIC][COUNTER], 2 * len(mock_mods)) + + def test_filter(self) -> None: + """Test if a mod can filter incoming TaskIns.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + context = Context(state=RecordSet()) + message = _get_dummy_flower_message() + + def filter_mod( + message: Message, + _1: Context, + _2: ClientAppCallable, + ) -> Message: + footprint.append("filter") + message.content.configs_records["filter"] = ConfigsRecord() + out_message = Message(metadata=message.metadata, content=RecordSet()) + out_message.content.configs_records["filter"] = ConfigsRecord() + # Skip calling app + return out_message + + # Execute + wrapped_app = make_ffn(mock_app, [filter_mod]) + out_message = wrapped_app(message, context) + + # Assert + self.assertEqual(footprint, ["filter"]) + # pylint: disable-next=no-member + 
self.assertEqual(list(message.content.configs_records.keys())[0], "filter") + self.assertEqual(list(out_message.content.configs_records.keys())[0], "filter") diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index ee4f70dc4dca..71681b783419 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,34 +17,32 @@ from typing import Any, Dict -from flwr.client.workload_state import WorkloadState +from flwr.common import Context, RecordSet class NodeState: - """State of a node where client nodes execute workloads.""" + """State of a node where client nodes execute runs.""" def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.workload_states: Dict[int, WorkloadState] = {} + self.run_contexts: Dict[int, Context] = {} - def register_workloadstate(self, workload_id: int) -> None: - """Register new workload state for this node.""" - if workload_id not in self.workload_states: - self.workload_states[workload_id] = WorkloadState({}) + def register_context(self, run_id: int) -> None: + """Register new run context for this node.""" + if run_id not in self.run_contexts: + self.run_contexts[run_id] = Context(state=RecordSet()) - def retrieve_workloadstate(self, workload_id: int) -> WorkloadState: - """Get workload state given a workload_id.""" - if workload_id in self.workload_states: - return self.workload_states[workload_id] + def retrieve_context(self, run_id: int) -> Context: + """Get run context given a run_id.""" + if run_id in self.run_contexts: + return self.run_contexts[run_id] raise RuntimeError( - f"WorkloadState for workload_id={workload_id} doesn't exist." - " A workload must be registered before it can be retrieved or updated " + f"Context for run_id={run_id} doesn't exist." + " A run context must be registered before it can be retrieved or updated " " by a client." 
) - def update_workloadstate( - self, workload_id: int, workload_state: WorkloadState - ) -> None: - """Update workload state.""" - self.workload_states[workload_id] = workload_state + def update_context(self, run_id: int, context: Context) -> None: + """Update run context.""" + self.run_contexts[run_id] = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index d9f9ae7db3b0..193f52661579 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -15,45 +15,51 @@ """Node state tests.""" +from typing import cast + from flwr.client.node_state import NodeState -from flwr.client.workload_state import WorkloadState -from flwr.proto.task_pb2 import TaskIns +from flwr.common import ConfigsRecord, Context +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + +def _run_dummy_task(context: Context) -> Context: + counter_value: str = "1" + if "counter" in context.state.configs_records.keys(): + counter_value = cast(str, context.state.configs_records["counter"]["count"]) + counter_value += "1" -def _run_dummy_task(state: WorkloadState) -> WorkloadState: - if "counter" in state.state: - state.state["counter"] += "1" - else: - state.state["counter"] = "1" + context.state.configs_records["counter"] = ConfigsRecord({"count": counter_value}) - return state + return context -def test_multiworkload_in_node_state() -> None: +def test_multirun_in_node_state() -> None: """Test basic NodeState logic.""" # Tasks to perform - tasks = [TaskIns(workload_id=w_id) for w_id in [0, 1, 1, 2, 3, 2, 1, 5]] - # the "tasks" is to count how many times each workload is executed + tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]] + # the "tasks" is to count how many times each run is executed expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState node_state = NodeState() for task in tasks: - w_id = task.workload_id + run_id = task.run_id # Register - 
node_state.register_workloadstate(workload_id=w_id) + node_state.register_context(run_id=run_id) - # Get workload state - state = node_state.retrieve_workloadstate(workload_id=w_id) + # Get run state + context = node_state.retrieve_context(run_id=run_id) # Run "task" - updated_state = _run_dummy_task(state) + updated_state = _run_dummy_task(context) - # Update workload state - node_state.update_workloadstate(workload_id=w_id, workload_state=updated_state) + # Update run state + node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for w_id, state in node_state.workload_states.items(): - assert state.state["counter"] == expected_values[w_id] + for run_id, context in node_state.run_contexts.items(): + assert ( + context.state.configs_records["counter"]["count"] == expected_values[run_id] + ) diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 8b0893ea30aa..0247958d88a9 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,9 +19,9 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client -from flwr.client.workload_state import WorkloadState from flwr.common import ( Config, + Context, NDArrays, Scalar, ndarrays_to_parameters, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: WorkloadState + context: Context def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. 
@@ -174,13 +174,13 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Convert to object to Client type and return it.""" @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) # Return FitRes parameters_prime, num_examples, metrics = results @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) # Return EvaluateRes loss, num_examples, metrics = results @@ -278,21 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> WorkloadState: - """Return state of underlying NumPyClient.""" - return self.numpy_client.get_state() # type: ignore +def _get_context(self: Client) -> Context: + """Return context of underlying NumPyClient.""" + return self.numpy_client.get_context() # type: ignore -def _set_state(self: Client, state: WorkloadState) -> None: - """Apply state to underlying NumPyClient.""" - self.numpy_client.set_state(state) # type: ignore +def _set_context(self: Client, context: Context) -> None: + """Apply context to underlying NumPyClient.""" + self.numpy_client.set_context(context) # type: ignore def 
_wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, - "get_state": _get_state, - "set_state": _set_state, + "get_context": _get_context, + "set_context": _set_context, } # Add wrapper type methods (if overridden) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d22b246dbd61..d2cc71ba3b3f 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -17,19 +17,19 @@ import sys from contextlib import contextmanager +from copy import copy from logging import ERROR, INFO, WARN from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast -from flwr.client.message_handler.task_handler import ( - configure_task_res, - get_task_ins, - validate_task_ins, - validate_task_res, -) +from flwr.client.message_handler.message_handler import validate_out_message +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.constant import MISSING_EXTRA_REST from flwr.common.logger import log -from flwr.proto.fleet_pb2 import ( +from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker +from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -38,8 +38,8 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 try: import requests @@ -48,7 +48,7 @@ KEY_NODE = "node" -KEY_TASK_INS = "current_task_ins" +KEY_METADATA = "in_message_metadata" PATH_CREATE_NODE: str = "api/v0/fleet/create-node" @@ -62,14 +62,15 @@ def 
http_request_response( server_address: str, insecure: bool, # pylint: disable=unused-argument + retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[ Union[bytes, str] ] = None, # pylint: disable=unused-argument ) -> Iterator[ Tuple[ - Callable[[], Optional[TaskIns]], - Callable[[TaskRes], None], + Callable[[], Optional[Message]], + Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], ] @@ -85,6 +86,12 @@ def http_request_response( The IPv6 address of the server with `http://` or `https://`. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Unused argument present for compatibilty. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after REST connection errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. 
root_certificates : Optional[Union[bytes, str]] (default: None) @@ -120,8 +127,8 @@ def http_request_response( "must be provided as a string path to the client.", ) - # Necessary state to link TaskRes to TaskIns - state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} + # Necessary state to validate messages to be sent + state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} # Enable create_node and delete_node to store node node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} @@ -135,7 +142,8 @@ def create_node() -> None: create_node_req_proto = CreateNodeRequest() create_node_req_bytes: bytes = create_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_CREATE_NODE}", headers={ "Accept": "application/protobuf", @@ -143,6 +151,7 @@ def create_node() -> None: }, data=create_node_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -177,7 +186,8 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_req_proto = DeleteNodeRequest(node=node) delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_DELETE_NODE}", headers={ "Accept": "application/protobuf", @@ -185,6 +195,7 @@ def delete_node() -> None: }, data=delete_node_req_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -204,7 +215,7 @@ def delete_node() -> None: PATH_PULL_TASK_INS, ) - def receive() -> Optional[TaskIns]: + def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node if node_store[KEY_NODE] is None: @@ -217,7 +228,8 @@ def receive() -> Optional[TaskIns]: pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString() # Request instructions (task) from server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PULL_TASK_INS}", headers={ 
"Accept": "application/protobuf", @@ -225,6 +237,7 @@ def receive() -> Optional[TaskIns]: }, data=pull_task_ins_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -253,40 +266,41 @@ def receive() -> Optional[TaskIns]: task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True + if task_ins is not None and not ( + task_ins.task.consumer.node_id == node.node_id + and validate_task_ins(task_ins) ): task_ins = None - # Remember `task_ins` until `task_res` is available - state[KEY_TASK_INS] = task_ins - - # Return the TaskIns if available + # Return the Message if available + message = None + state[KEY_METADATA] = None if task_ins is not None: + message = message_from_taskins(task_ins) + state[KEY_METADATA] = copy(message.metadata) log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS) - return task_ins + return message - def send(task_res: TaskRes) -> None: + def send(message: Message) -> None: """Send task result back to server.""" # Get Node if node_store[KEY_NODE] is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) - if state[KEY_TASK_INS] is None: - log(ERROR, "No current TaskIns") + # Get incoming message + in_metadata = state[KEY_METADATA] + if in_metadata is None: + log(ERROR, "No current message") return - task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS]) - - # Check if fields to be set are not initialized - if not validate_task_res(task_res): - state[KEY_TASK_INS] = None - log(ERROR, "TaskRes has been initialized accidentally") + # Validate out message + if not validate_out_message(message, in_metadata): + log(ERROR, "Invalid out message") + return - # Configure TaskRes - task_res = configure_task_res(task_res, task_ins, node) + # Construct TaskRes + task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes 
push_task_res_request_proto = PushTaskResRequest(task_res_list=[task_res]) @@ -295,7 +309,8 @@ def send(task_res: TaskRes) -> None: ) # Send ClientMessage to server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PUSH_TASK_RES}", headers={ "Accept": "application/protobuf", @@ -303,9 +318,10 @@ def send(task_res: TaskRes) -> None: }, data=push_task_res_request_bytes, verify=verify, + timeout=None, ) - state[KEY_TASK_INS] = None + state[KEY_METADATA] = None # Check status code and headers if res.status_code != 200: diff --git a/src/py/flwr/client/secure_aggregation/handler.py b/src/py/flwr/client/secure_aggregation/handler.py deleted file mode 100644 index 487ed842c93f..000000000000 --- a/src/py/flwr/client/secure_aggregation/handler.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Message Handler for Secure Aggregation (abstract base class).""" - - -from abc import ABC, abstractmethod -from typing import Dict - -from flwr.common.typing import Value - - -class SecureAggregationHandler(ABC): - """Abstract base class for secure aggregation message handlers.""" - - @abstractmethod - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming Secure Aggregation message and return results. 
- - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the Secure Aggregation protocol. - """ diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py deleted file mode 100644 index efbb00a9d916..000000000000 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ /dev/null @@ -1,488 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Message handler for the SecAgg+ protocol.""" - - -import os -from dataclasses import dataclass, field -from logging import ERROR, INFO, WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast - -from flwr.client.client import Client -from flwr.client.numpy_client import NumPyClient -from flwr.common import ( - bytes_to_ndarray, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) -from flwr.common.logger import log -from flwr.common.secure_aggregation.crypto.shamir import create_shares -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, - bytes_to_public_key, - decrypt, - encrypt, - generate_key_pairs, - generate_shared_key, - private_key_to_bytes, - public_key_to_bytes, -) -from flwr.common.secure_aggregation.ndarrays_arithmetic import ( - factor_combine, - parameters_addition, - parameters_mod, - parameters_multiply, - parameters_subtraction, -) -from flwr.common.secure_aggregation.quantization import quantize -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_DESTINATION_LIST, - KEY_MASKED_PARAMETERS, - KEY_MOD_RANGE, - KEY_PARAMETERS, - KEY_PUBLIC_KEY_1, - KEY_PUBLIC_KEY_2, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SECURE_ID_LIST, - KEY_SHARE_LIST, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, - STAGES, -) -from flwr.common.secure_aggregation.secaggplus_utils import ( - pseudo_rand_gen, - share_keys_plaintext_concat, - share_keys_plaintext_separate, -) -from flwr.common.typing import FitIns, Value - -from .handler import SecureAggregationHandler - - -@dataclass -# pylint: disable-next=too-many-instance-attributes -class SecAggPlusState: - """State 
of the SecAgg+ protocol.""" - - sid: int = 0 - sample_num: int = 0 - share_num: int = 0 - threshold: int = 0 - clipping_range: float = 0.0 - target_range: int = 0 - mod_range: int = 0 - - # Secret key (sk) and public key (pk) - sk1: bytes = b"" - pk1: bytes = b"" - sk2: bytes = b"" - pk2: bytes = b"" - - # Random seed for generating the private mask - rd_seed: bytes = b"" - - rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) - sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) - # The dict of the shared secrets from sk2 - ss2_dict: Dict[int, bytes] = field(default_factory=dict) - public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) - - client: Optional[Union[Client, NumPyClient]] = None - - -class SecAggPlusHandler(SecureAggregationHandler): - """Message handler for the SecAgg+ protocol.""" - - _shared_state = SecAggPlusState() - _current_stage = STAGE_UNMASK - - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming message and return results, following the SecAgg+ protocol. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the SecAgg+ protocol. - """ - # Check if self is a client - if not isinstance(self, (Client, NumPyClient)): - raise TypeError( - "The subclass of SecAggPlusHandler must be " - "the subclass of Client or NumPyClient." 
- ) - - # Check the validity of the next stage - check_stage(self._current_stage, named_values) - - # Update the current stage - self._current_stage = cast(str, named_values.pop(KEY_STAGE)) - - # Check the validity of the `named_values` based on the current stage - check_named_values(self._current_stage, named_values) - - # Execute - if self._current_stage == STAGE_SETUP: - self._shared_state = SecAggPlusState(client=self) - return _setup(self._shared_state, named_values) - if self._current_stage == STAGE_SHARE_KEYS: - return _share_keys(self._shared_state, named_values) - if self._current_stage == STAGE_COLLECT_MASKED_INPUT: - return _collect_masked_input(self._shared_state, named_values) - if self._current_stage == STAGE_UNMASK: - return _unmask(self._shared_state, named_values) - raise ValueError(f"Unknown secagg stage: {self._current_stage}") - - -def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the next stage.""" - # Check the existence of KEY_STAGE - if KEY_STAGE not in named_values: - raise KeyError( - f"The required key '{KEY_STAGE}' is missing from the input `named_values`." - ) - - # Check the value type of the KEY_STAGE - next_stage = named_values[KEY_STAGE] - if not isinstance(next_stage, str): - raise TypeError( - f"The value for the key '{KEY_STAGE}' must be of type {str}, " - f"but got {type(next_stage)} instead." 
- ) - - # Check the validity of the next stage - if next_stage == STAGE_SETUP: - if current_stage != STAGE_UNMASK: - log(WARNING, "Restart from the setup stage") - # If stage is not "setup", - # the stage from `named_values` should be the expected next stage - else: - expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)] - if next_stage != expected_next_stage: - raise ValueError( - "Abort secure aggregation: " - f"expect {expected_next_stage} stage, but receive {next_stage} stage" - ) - - -# pylint: disable-next=too-many-branches -def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the input `named_values`.""" - # Check `named_values` for the setup stage - if stage == STAGE_SETUP: - key_type_pairs = [ - (KEY_SAMPLE_NUMBER, int), - (KEY_SECURE_ID, int), - (KEY_SHARE_NUMBER, int), - (KEY_THRESHOLD, int), - (KEY_CLIPPING_RANGE, float), - (KEY_TARGET_RANGE, int), - (KEY_MOD_RANGE, int), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_SETUP}: the required key '{key}' is " - "missing from the input `named_values`." - ) - # Bool is a subclass of int in Python, - # so `isinstance(v, int)` will return True even if v is a boolean. - # pylint: disable-next=unidiomatic-typecheck - if type(named_values[key]) is not expected_type: - raise TypeError( - f"Stage {STAGE_SETUP}: The value for the key '{key}' " - f"must be of type {expected_type}, " - f"but got {type(named_values[key])} instead." - ) - elif stage == STAGE_SHARE_KEYS: - for key, value in named_values.items(): - if ( - not isinstance(value, list) - or len(value) != 2 - or not isinstance(value[0], bytes) - or not isinstance(value[1], bytes) - ): - raise TypeError( - f"Stage {STAGE_SHARE_KEYS}: " - f"the value for the key '{key}' must be a list of two bytes." 
- ) - elif stage == STAGE_COLLECT_MASKED_INPUT: - key_type_pairs = [ - (KEY_CIPHERTEXT_LIST, bytes), - (KEY_SOURCE_LIST, int), - (KEY_PARAMETERS, bytes), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_COLLECT_MASKED_INPUT}: " - f"the required key '{key}' is " - "missing from the input `named_values`." - ) - if not isinstance(named_values[key], list) or any( - elm - for elm in cast(List[Any], named_values[key]) - # pylint: disable-next=unidiomatic-typecheck - if type(elm) is not expected_type - ): - raise TypeError( - f"Stage {STAGE_COLLECT_MASKED_INPUT}: " - f"the value for the key '{key}' " - f"must be of type List[{expected_type.__name__}]" - ) - elif stage == STAGE_UNMASK: - key_type_pairs = [ - (KEY_ACTIVE_SECURE_ID_LIST, int), - (KEY_DEAD_SECURE_ID_LIST, int), - ] - for key, expected_type in key_type_pairs: - if key not in named_values: - raise KeyError( - f"Stage {STAGE_UNMASK}: " - f"the required key '{key}' is " - "missing from the input `named_values`." 
- ) - if not isinstance(named_values[key], list) or any( - elm - for elm in cast(List[Any], named_values[key]) - # pylint: disable-next=unidiomatic-typecheck - if type(elm) is not expected_type - ): - raise TypeError( - f"Stage {STAGE_UNMASK}: " - f"the value for the key '{key}' " - f"must be of type List[{expected_type.__name__}]" - ) - else: - raise ValueError(f"Unknown secagg stage: {stage}") - - -def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: - # Assigning parameter values to object fields - sec_agg_param_dict = named_values - state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER]) - state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID]) - log(INFO, "Client %d: starting stage 0...", state.sid) - - state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER]) - state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD]) - state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE]) - state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE]) - state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE]) - - # Dictionaries containing client secure IDs as keys - # and their respective secret shares as values. - state.rd_seed_share_dict = {} - state.sk1_share_dict = {} - # Dictionary containing client secure IDs as keys - # and their respective shared secrets (with this client) as values. - state.ss2_dict = {} - - # Create 2 sets private public key pairs - # One for creating pairwise masks - # One for encrypting message to distribute shares - sk1, pk1 = generate_key_pairs() - sk2, pk2 = generate_key_pairs() - - state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) - state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) - log(INFO, "Client %d: stage 0 completes. 
uploading public keys...", state.sid) - return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2} - - -# pylint: disable-next=too-many-locals -def _share_keys( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values) - key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} - log(INFO, "Client %d: starting stage 1...", state.sid) - state.public_keys_dict = key_dict - - # Check if the size is larger than threshold - if len(state.public_keys_dict) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") - - # Check if all public keys are unique - pk_list: List[bytes] = [] - for pk1, pk2 in state.public_keys_dict.values(): - pk_list.append(pk1) - pk_list.append(pk2) - if len(set(pk_list)) != len(pk_list): - raise Exception("Some public keys are identical") - - # Check if public keys of this client are correct in the dictionary - if ( - state.public_keys_dict[state.sid][0] != state.pk1 - or state.public_keys_dict[state.sid][1] != state.pk2 - ): - raise Exception( - "Own public keys are displayed in dict incorrectly, should not happen!" 
- ) - - # Generate the private mask seed - state.rd_seed = os.urandom(32) - - # Create shares for the private mask seed and the first private key - b_shares = create_shares(state.rd_seed, state.threshold, state.share_num) - sk1_shares = create_shares(state.sk1, state.threshold, state.share_num) - - srcs, dsts, ciphertexts = [], [], [] - - # Distribute shares - for idx, (sid, (_, pk2)) in enumerate(state.public_keys_dict.items()): - if sid == state.sid: - state.rd_seed_share_dict[state.sid] = b_shares[idx] - state.sk1_share_dict[state.sid] = sk1_shares[idx] - else: - shared_key = generate_shared_key( - bytes_to_private_key(state.sk2), - bytes_to_public_key(pk2), - ) - state.ss2_dict[sid] = shared_key - plaintext = share_keys_plaintext_concat( - state.sid, sid, b_shares[idx], sk1_shares[idx] - ) - ciphertext = encrypt(shared_key, plaintext) - srcs.append(state.sid) - dsts.append(sid) - ciphertexts.append(ciphertext) - - log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid) - return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts} - - -# pylint: disable-next=too-many-locals -def _collect_masked_input( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: - log(INFO, "Client %d: starting stage 2...", state.sid) - available_clients: List[int] = [] - ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) - srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) - if len(ciphertexts) + 1 < state.threshold: - raise Exception("Not enough available neighbour clients.") - - # Decrypt ciphertexts, verify their sources, and store shares. 
- for src, ciphertext in zip(srcs, ciphertexts): - shared_key = state.ss2_dict[src] - plaintext = decrypt(shared_key, ciphertext) - actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate( - plaintext - ) - available_clients.append(src) - if src != actual_src: - raise ValueError( - f"Client {state.sid}: received ciphertext " - f"from {actual_src} instead of {src}." - ) - if dst != state.sid: - ValueError( - f"Client {state.sid}: received an encrypted message" - f"for Client {dst} from Client {src}." - ) - state.rd_seed_share_dict[src] = rd_seed_share - state.sk1_share_dict[src] = sk1_share - - # Fit client - parameters_bytes = cast(List[bytes], named_values[KEY_PARAMETERS]) - parameters = [bytes_to_ndarray(w) for w in parameters_bytes] - if isinstance(state.client, Client): - fit_res = state.client.fit( - FitIns(parameters=ndarrays_to_parameters(parameters), config={}) - ) - parameters_factor = fit_res.num_examples - parameters = parameters_to_ndarrays(fit_res.parameters) - elif isinstance(state.client, NumPyClient): - parameters, parameters_factor, _ = state.client.fit(parameters, {}) - else: - log(ERROR, "Client %d: fit function is missing.", state.sid) - - # Quantize parameter update (vector) - quantized_parameters = quantize( - parameters, state.clipping_range, state.target_range - ) - - quantized_parameters = parameters_multiply(quantized_parameters, parameters_factor) - quantized_parameters = factor_combine(parameters_factor, quantized_parameters) - - dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] - - # Add private mask - private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) - quantized_parameters = parameters_addition(quantized_parameters, private_mask) - - for client_id in available_clients: - # Add pairwise masks - shared_key = generate_shared_key( - bytes_to_private_key(state.sk1), - bytes_to_public_key(state.public_keys_dict[client_id][0]), - ) - pairwise_mask = 
pseudo_rand_gen(shared_key, state.mod_range, dimensions_list) - if state.sid > client_id: - quantized_parameters = parameters_addition( - quantized_parameters, pairwise_mask - ) - else: - quantized_parameters = parameters_subtraction( - quantized_parameters, pairwise_mask - ) - - # Take mod of final weight update vector and return to server - quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) - log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid) - return { - KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] - } - - -def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: - log(INFO, "Client %d: starting stage 3...", state.sid) - - active_sids = cast(List[int], named_values[KEY_ACTIVE_SECURE_ID_LIST]) - dead_sids = cast(List[int], named_values[KEY_DEAD_SECURE_ID_LIST]) - # Send private mask seed share for every avaliable client (including itclient) - # Send first private key share for building pairwise mask for every dropped client - if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") - - sids, shares = [], [] - sids += active_sids - shares += [state.rd_seed_share_dict[sid] for sid in active_sids] - sids += dead_sids - shares += [state.sk1_share_dict[sid] for sid in dead_sids] - - log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid) - return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares} diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py deleted file mode 100644 index 9693a46af989..000000000000 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The SecAgg+ protocol handler tests.""" - -import unittest -from itertools import product -from typing import Any, Dict, List, cast - -from flwr.client import NumPyClient -from flwr.common.secure_aggregation.secaggplus_constants import ( - KEY_ACTIVE_SECURE_ID_LIST, - KEY_CIPHERTEXT_LIST, - KEY_CLIPPING_RANGE, - KEY_DEAD_SECURE_ID_LIST, - KEY_MOD_RANGE, - KEY_PARAMETERS, - KEY_SAMPLE_NUMBER, - KEY_SECURE_ID, - KEY_SHARE_NUMBER, - KEY_SOURCE_LIST, - KEY_STAGE, - KEY_TARGET_RANGE, - KEY_THRESHOLD, - STAGE_COLLECT_MASKED_INPUT, - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_UNMASK, - STAGES, -) -from flwr.common.typing import Value - -from .secaggplus_handler import SecAggPlusHandler, check_named_values - - -class EmptyFlowerNumPyClient(NumPyClient, SecAggPlusHandler): - """Empty NumPyClient.""" - - -class TestSecAggPlusHandler(unittest.TestCase): - """Test the SecAgg+ protocol handler.""" - - def test_invalid_handler(self) -> None: - """Test invalid handler.""" - handler = SecAggPlusHandler() - - with self.assertRaises(TypeError): - handler.handle_secure_aggregation({}) - - def test_stage_transition(self) -> None: - """Test stage transition.""" - handler = EmptyFlowerNumPyClient() - - assert STAGES == ( - STAGE_SETUP, - STAGE_SHARE_KEYS, - STAGE_COLLECT_MASKED_INPUT, - STAGE_UNMASK, - ) - - valid_transitions = { - # From one stage to the next 
stage - (STAGE_UNMASK, STAGE_SETUP), - (STAGE_SETUP, STAGE_SHARE_KEYS), - (STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT), - (STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK), - # From any stage to the initial stage - # Such transitions will log a warning. - (STAGE_SETUP, STAGE_SETUP), - (STAGE_SHARE_KEYS, STAGE_SETUP), - (STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP), - } - - invalid_transitions = set(product(STAGES, STAGES)).difference(valid_transitions) - - # Test valid transitions - # If the next stage is valid, the function should update the current stage - # and then raise KeyError or other exceptions when trying to execute SA. - for current_stage, next_stage in valid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage - - with self.assertRaises(KeyError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == next_stage - - # Test invalid transitions - # If the next stage is invalid, the function should raise ValueError - for current_stage, next_stage in invalid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage - - with self.assertRaises(ValueError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == current_stage - - def test_stage_setup_check(self) -> None: - """Test content checking for the setup stage.""" - handler = EmptyFlowerNumPyClient() - - valid_key_type_pairs = [ - (KEY_SAMPLE_NUMBER, int), - (KEY_SECURE_ID, int), - (KEY_SHARE_NUMBER, int), - (KEY_THRESHOLD, int), - (KEY_CLIPPING_RANGE, float), - (KEY_TARGET_RANGE, int), - (KEY_MOD_RANGE, int), - ] - - type_to_test_value: Dict[type, Value] = { - int: 10, - bool: True, - float: 1.0, - str: "test", - bytes: b"test", - } - - valid_named_values: Dict[str, Value] = { - key: type_to_test_value[value_type] - for key, value_type in valid_key_type_pairs - } - - # 
Test valid `named_values` - try: - check_named_values(STAGE_SETUP, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SETUP - - # Test invalid `named_values` - for key, value_type in valid_key_type_pairs: - invalid_named_values = valid_named_values.copy() - - # Test wrong value type for the key - for other_type, other_value in type_to_test_value.items(): - if other_type == value_type: - continue - invalid_named_values[key] = other_value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - # Test missing key - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - def test_stage_share_keys_check(self) -> None: - """Test content checking for the share keys stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - "1": [b"public key 1", b"public key 2"], - "2": [b"public key 1", b"public key 2"], - "3": [b"public key 1", b"public key 2"], - } - - # Test valid `named_values` - try: - check_named_values(STAGE_SHARE_KEYS, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SHARE_KEYS - - # Test invalid `named_values` - invalid_values: List[Value] = [ - b"public key 1", - [b"public key 1"], - [b"public key 1", b"public key 2", b"public key 3"], - ] - - for value in invalid_values: - invalid_named_values = valid_named_values.copy() - invalid_named_values["1"] = value - - # pylint: disable-next=protected-access - 
handler._current_stage = STAGE_SETUP - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) - - def test_stage_collect_masked_input_check(self) -> None: - """Test content checking for the collect masked input stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - KEY_CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], - KEY_SOURCE_LIST: [32, 51324, 32324123, -3], - KEY_PARAMETERS: [b"params1", b"params2"], - } - - # Test valid `named_values` - try: - check_named_values(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT - - # Test invalid `named_values` - # Test missing keys - for key in list(valid_named_values.keys()): - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) - - # Test wrong value type for the key - for key in valid_named_values: - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(3.1415926) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) - - def test_stage_unmask_check(self) -> None: - """Test content checking for the unmasking stage.""" - handler = EmptyFlowerNumPyClient() - - valid_named_values: Dict[str, Value] = { - KEY_ACTIVE_SECURE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - KEY_DEAD_SECURE_ID_LIST: [32, 51324, 32324123, -3], - } - - # Test valid `named_values` - try: - 
check_named_values(STAGE_UNMASK, valid_named_values.copy()) - # pylint: disable-next=broad-except - except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") - - # Set the stage - valid_named_values[KEY_STAGE] = STAGE_UNMASK - - # Test invalid `named_values` - # Test missing keys - for key in list(valid_named_values.keys()): - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT - with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) - - # Test wrong value type for the key - for key in valid_named_values: - if key == KEY_STAGE: - continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(True) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT - with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 2c1f7506592c..956ac7a15c05 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -14,30 +14,15 @@ # ============================================================================== """Custom types for Flower clients.""" -from dataclasses import dataclass + from typing import Callable -from flwr.client.workload_state import WorkloadState -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.common import Context, Message from .client import Client as Client - -@dataclass -class Fwd: - """.""" - - task_ins: TaskIns - state: WorkloadState - - -@dataclass -class Bwd: - """.""" - - task_res: TaskRes - state: WorkloadState - - -FlowerCallable = Callable[[Fwd], Bwd] +# Compatibility ClientFn = Callable[[str], Client] + +ClientAppCallable = Callable[[Message, Context], Message] +Mod = Callable[[Message, 
Context, ClientAppCallable], Message] diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 2f45de45dfc3..9f9ff7ebc68a 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -15,14 +15,26 @@ """Common components shared between server and client.""" +from .constant import MessageType as MessageType +from .constant import MessageTypeLegacy as MessageTypeLegacy +from .context import Context as Context from .date import now as now from .grpc import GRPC_MAX_MESSAGE_LENGTH from .logger import configure as configure from .logger import log as log +from .message import Error as Error +from .message import Message as Message +from .message import Metadata as Metadata from .parameter import bytes_to_ndarray as bytes_to_ndarray from .parameter import ndarray_to_bytes as ndarray_to_bytes from .parameter import ndarrays_to_parameters as ndarrays_to_parameters from .parameter import parameters_to_ndarrays as parameters_to_ndarrays +from .record import Array as Array +from .record import ConfigsRecord as ConfigsRecord +from .record import MetricsRecord as MetricsRecord +from .record import ParametersRecord as ParametersRecord +from .record import RecordSet as RecordSet +from .record import array_from_numpy as array_from_numpy from .telemetry import EventType as EventType from .telemetry import event as event from .typing import ClientMessage as ClientMessage @@ -49,11 +61,15 @@ from .typing import Status as Status __all__ = [ + "Array", + "array_from_numpy", "bytes_to_ndarray", "ClientMessage", "Code", "Config", + "ConfigsRecord", "configure", + "Context", "DisconnectRes", "EvaluateIns", "EvaluateRes", @@ -61,14 +77,20 @@ "EventType", "FitIns", "FitRes", + "Error", "GetParametersIns", "GetParametersRes", "GetPropertiesIns", "GetPropertiesRes", "GRPC_MAX_MESSAGE_LENGTH", "log", + "Message", + "MessageType", + "MessageTypeLegacy", + "Metadata", "Metrics", "MetricsAggregationFn", + "MetricsRecord", "ndarray_to_bytes", "now", 
"NDArray", @@ -76,8 +98,10 @@ "ndarrays_to_parameters", "Parameters", "parameters_to_ndarrays", + "ParametersRecord", "Properties", "ReconnectIns", + "RecordSet", "Scalar", "ServerMessage", "Status", diff --git a/src/py/flwr/common/address_test.py b/src/py/flwr/common/address_test.py index c12dd5fd289e..420b89871d69 100644 --- a/src/py/flwr/common/address_test.py +++ b/src/py/flwr/common/address_test.py @@ -109,10 +109,10 @@ def test_domain_correct() -> None: """Test if a correct domain address is correctly parsed.""" # Prepare addresses = [ - ("flower.dev:123", ("flower.dev", 123, None)), - ("sub.flower.dev:123", ("sub.flower.dev", 123, None)), - ("sub2.sub1.flower.dev:123", ("sub2.sub1.flower.dev", 123, None)), - ("s5.s4.s3.s2.s1.flower.dev:123", ("s5.s4.s3.s2.s1.flower.dev", 123, None)), + ("flower.ai:123", ("flower.ai", 123, None)), + ("sub.flower.ai:123", ("sub.flower.ai", 123, None)), + ("sub2.sub1.flower.ai:123", ("sub2.sub1.flower.ai", 123, None)), + ("s5.s4.s3.s2.s1.flower.ai:123", ("s5.s4.s3.s2.s1.flower.ai", 123, None)), ("localhost:123", ("localhost", 123, None)), ("https://localhost:123", ("https://localhost", 123, None)), ("http://localhost:123", ("http://localhost", 123, None)), @@ -130,8 +130,8 @@ def test_domain_incorrect() -> None: """Test if an incorrect domain address returns None.""" # Prepare addresses = [ - "flower.dev", - "flower.dev:65536", + "flower.ai", + "flower.ai:65536", ] for address in addresses: diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 49802f2815be..7d30a10f5881 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -15,6 +15,8 @@ """Flower constants.""" +from __future__ import annotations + MISSING_EXTRA_REST = """ Extra dependencies required for using the REST-based Fleet API are missing. 
@@ -26,8 +28,43 @@ TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi" TRANSPORT_TYPE_GRPC_RERE = "grpc-rere" TRANSPORT_TYPE_REST = "rest" +TRANSPORT_TYPE_VCE = "vce" TRANSPORT_TYPES = [ TRANSPORT_TYPE_GRPC_BIDI, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_VCE, ] + + +class MessageType: + """Message type.""" + + TRAIN = "train" + EVALUATE = "evaluate" + QUERY = "query" + + def __new__(cls) -> MessageType: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class MessageTypeLegacy: + """Legacy message type.""" + + GET_PROPERTIES = "get_properties" + GET_PARAMETERS = "get_parameters" + + def __new__(cls) -> MessageTypeLegacy: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class SType: + """Serialisation type.""" + + NUMPY = "numpy.ndarray" + + def __new__(cls) -> SType: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py new file mode 100644 index 000000000000..b6349307d150 --- /dev/null +++ b/src/py/flwr/common/context.py @@ -0,0 +1,38 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Context.""" + + +from dataclasses import dataclass + +from .record import RecordSet + + +@dataclass +class Context: + """State of your run. + + Parameters + ---------- + state : RecordSet + Holds records added by the entity in a given run and that will stay local. + This means that the data it holds will never leave the system it's running from. + This can be used as an intermediate storage or scratchpad when + executing mods. It can also be used as a memory to access + at different points during the lifecycle of this entity (e.g. across + multiple rounds) + """ + + state: RecordSet diff --git a/src/py/flwr/common/differential_privacy.py b/src/py/flwr/common/differential_privacy.py new file mode 100644 index 000000000000..85dc198ef8a0 --- /dev/null +++ b/src/py/flwr/common/differential_privacy.py @@ -0,0 +1,176 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utility functions for differential privacy.""" + + +from logging import WARNING +from typing import Optional, Tuple + +import numpy as np + +from flwr.common import ( + NDArrays, + Parameters, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log + + +def get_norm(input_arrays: NDArrays) -> float: + """Compute the L2 norm of the flattened input.""" + array_norms = [np.linalg.norm(array.flat) for array in input_arrays] + # pylint: disable=consider-using-generator + return float(np.sqrt(sum([norm**2 for norm in array_norms]))) + + +def add_gaussian_noise_inplace(input_arrays: NDArrays, std_dev: float) -> None: + """Add Gaussian noise to each element of the input arrays.""" + for array in input_arrays: + array += np.random.normal(0, std_dev, array.shape) + + +def clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> None: + """Clip model update based on the clipping norm in-place. + + FlatClip method of the paper: https://arxiv.org/abs/1710.06963 + """ + input_norm = get_norm(input_arrays) + scaling_factor = min(1, clipping_norm / input_norm) + for array in input_arrays: + array *= scaling_factor + + +def compute_stdv( + noise_multiplier: float, clipping_norm: float, num_sampled_clients: int +) -> float: + """Compute standard deviation for noise addition. + + Paper: https://arxiv.org/abs/1710.06963 + """ + return float((noise_multiplier * clipping_norm) / num_sampled_clients) + + +def compute_clip_model_update( + param1: NDArrays, param2: NDArrays, clipping_norm: float +) -> None: + """Compute model update (param1 - param2) and clip it. 
+ + Then add the clipped value to param1.""" + model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)] + clip_inputs_inplace(model_update, clipping_norm) + + for i, _ in enumerate(param2): + param1[i] = param2[i] + model_update[i] + + +def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool: + """Clip model update based on the clipping norm in-place. + + It returns true if scaling_factor < 1 which is used for norm_bit + FlatClip method of the paper: https://arxiv.org/abs/1710.06963 + """ + input_norm = get_norm(input_arrays) + scaling_factor = min(1, clipping_norm / input_norm) + for array in input_arrays: + array *= scaling_factor + return scaling_factor < 1 + + +def compute_adaptive_clip_model_update( + param1: NDArrays, param2: NDArrays, clipping_norm: float +) -> bool: + """Compute model update, clip it, then add the clipped value to param1. + + model update = param1 - param2 + Return the norm_bit + """ + model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)] + norm_bit = adaptive_clip_inputs_inplace(model_update, clipping_norm) + + for i, _ in enumerate(param2): + param1[i] = param2[i] + model_update[i] + + return norm_bit + + +def add_gaussian_noise_to_params( + model_params: Parameters, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, +) -> Parameters: + """Add gaussian noise to model parameters.""" + model_params_ndarrays = parameters_to_ndarrays(model_params) + add_gaussian_noise_inplace( + model_params_ndarrays, + compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients), + ) + return ndarrays_to_parameters(model_params_ndarrays) + + +def compute_adaptive_noise_params( + noise_multiplier: float, + num_sampled_clients: float, + clipped_count_stddev: Optional[float], +) -> Tuple[float, float]: + """Compute noising parameters for the adaptive clipping. 
+ + Paper: https://arxiv.org/abs/1905.03871 + """ + if noise_multiplier > 0: + if clipped_count_stddev is None: + clipped_count_stddev = num_sampled_clients / 20 + if noise_multiplier >= 2 * clipped_count_stddev: + raise ValueError( + f"If not specified, `clipped_count_stddev` is set to " + f"`num_sampled_clients`/20 by default. This value " + f"({num_sampled_clients / 20}) is too low to achieve the " + f"desired effective `noise_multiplier` ({noise_multiplier}). " + f"Consider increasing `clipped_count_stddev` or decreasing " + f"`noise_multiplier`." + ) + noise_multiplier_value = ( + noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2) + ) ** -0.5 + + adding_noise = noise_multiplier_value / noise_multiplier + if adding_noise >= 2: + log( + WARNING, + "A significant amount of noise (%s) has to be " + "added. Consider increasing `clipped_count_stddev` or " + "`num_sampled_clients`.", + adding_noise, + ) + + else: + if clipped_count_stddev is None: + clipped_count_stddev = 0.0 + noise_multiplier_value = 0.0 + + return clipped_count_stddev, noise_multiplier_value + + +def add_localdp_gaussian_noise_to_params( + model_params: Parameters, sensitivity: float, epsilon: float, delta: float +) -> Parameters: + """Add local DP gaussian noise to model parameters.""" + model_params_ndarrays = parameters_to_ndarrays(model_params) + add_gaussian_noise_inplace( + model_params_ndarrays, + sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon, + ) + return ndarrays_to_parameters(model_params_ndarrays) diff --git a/src/py/flwr/common/differential_privacy_constants.py b/src/py/flwr/common/differential_privacy_constants.py new file mode 100644 index 000000000000..415759dfdf0b --- /dev/null +++ b/src/py/flwr/common/differential_privacy_constants.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants for differential privacy.""" + + +KEY_CLIPPING_NORM = "clipping_norm" +KEY_NORM_BIT = "norm_bit" +CLIENTS_DISCREPANCY_WARNING = ( + "The number of clients returning parameters (%s)" + " differs from the number of sampled clients (%s)." + " This could impact the differential privacy guarantees," + " potentially leading to privacy leakage or inadequate noise calibration." +) diff --git a/src/py/flwr/common/differential_privacy_test.py b/src/py/flwr/common/differential_privacy_test.py new file mode 100644 index 000000000000..ea2e193df2f3 --- /dev/null +++ b/src/py/flwr/common/differential_privacy_test.py @@ -0,0 +1,209 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Differential Privacy (DP) utility functions tests.""" + + +import numpy as np + +from .differential_privacy import ( + add_gaussian_noise_inplace, + clip_inputs_inplace, + compute_adaptive_noise_params, + compute_clip_model_update, + compute_stdv, + get_norm, +) + + +def test_add_gaussian_noise_inplace() -> None: + """Test add_gaussian_noise_inplace function.""" + # Prepare + update = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])] + std_dev = 0.1 + + # Execute + add_gaussian_noise_inplace(update, std_dev) + + # Assert + # Check that the shape of the result is the same as the input + for layer in update: + assert layer.shape == (2, 2) + + # Check that the values have been changed and are not equal to the original update + for layer in update: + assert not np.array_equal( + layer, [[1.0, 2.0], [3.0, 4.0]] + ) and not np.array_equal(layer, [[5.0, 6.0], [7.0, 8.0]]) + + # Check that the noise has been added + for layer in update: + noise_added = ( + layer - np.array([[1.0, 2.0], [3.0, 4.0]]) + if np.array_equal(layer, [[1.0, 2.0], [3.0, 4.0]]) + else layer - np.array([[5.0, 6.0], [7.0, 8.0]]) + ) + assert np.any(np.abs(noise_added) > 0) + + +def test_get_norm() -> None: + """Test get_norm function.""" + # Prepare + update = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + + # Execute + result = get_norm(update) + + expected = float( + np.linalg.norm(np.concatenate([sub_update.flatten() for sub_update in update])) + ) + + # Assert + assert expected == result + + +def test_clip_inputs_inplace() -> None: + """Test clip_inputs_inplace function.""" + # Prepare + updates = [ + np.array([[1.5, -0.5], [2.0, -1.0]]), + np.array([0.5, -0.5]), + np.array([[-0.5, 1.5], [-1.0, 2.0]]), + np.array([-0.5, 0.5]), + ] + clipping_norm = 1.5 + + original_updates = [np.copy(update) for update in updates] + + # Execute + clip_inputs_inplace(updates, clipping_norm) + + # 
Assert + for updated, original_update in zip(updates, original_updates): + clip_norm = np.linalg.norm(original_update) + assert np.all(updated <= clip_norm) and np.all(updated >= -clip_norm) + + +def test_compute_stdv() -> None: + """Test compute_stdv function.""" + # Prepare + noise_multiplier = 1.0 + clipping_norm = 0.5 + num_sampled_clients = 10 + + # Execute + stdv = compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients) + + # Assert + expected_stdv = float((noise_multiplier * clipping_norm) / num_sampled_clients) + assert stdv == expected_stdv + + +def test_compute_clip_model_update() -> None: + """Test compute_clip_model_update function.""" + # Prepare + param1 = [ + np.array([0.5, 1.5, 2.5]), + np.array([3.5, 4.5, 5.5]), + np.array([6.5, 7.5, 8.5]), + ] + param2 = [ + np.array([1.0, 2.0, 3.0]), + np.array([4.0, 5.0, 6.0]), + np.array([7.0, 8.0, 9.0]), + ] + clipping_norm = 4 + + expected_result = [ + np.array([0.5, 1.5, 2.5]), + np.array([3.5, 4.5, 5.5]), + np.array([6.5, 7.5, 8.5]), + ] + + # Execute + compute_clip_model_update(param1, param2, clipping_norm) + + # Assert + for i, param in enumerate(param1): + np.testing.assert_array_almost_equal(param, expected_result[i]) + + +def test_compute_adaptive_noise_params() -> None: + """Test compute_adaptive_noise_params function.""" + # Test valid input with positive noise_multiplier + noise_multiplier = 1.0 + num_sampled_clients = 100.0 + clipped_count_stddev = None + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] > 0.0 + assert result[1] > 0.0 + + # Test valid input with zero noise_multiplier + noise_multiplier = 0.0 + num_sampled_clients = 50.0 + clipped_count_stddev = None + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) 
== 2 + assert result[0] == 0.0 + assert result[1] == 0.0 + + # Test valid input with specified clipped_count_stddev + noise_multiplier = 3.0 + num_sampled_clients = 200.0 + clipped_count_stddev = 5.0 + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + + # Assert + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0] == clipped_count_stddev + assert result[1] > 0.0 + + # Test invalid input with noise_multiplier >= 2 * clipped_count_stddev + noise_multiplier = 10.0 + num_sampled_clients = 100.0 + clipped_count_stddev = None + try: + compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + except ValueError: + pass + else: + raise AssertionError("Expected ValueError not raised.") + + # Test intermediate calculation + noise_multiplier = 3.0 + num_sampled_clients = 200.0 + clipped_count_stddev = 5.0 + result = compute_adaptive_noise_params( + noise_multiplier, num_sampled_clients, clipped_count_stddev + ) + temp_value = (noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)) ** -0.5 + + # Assert + assert np.isclose(result[1], temp_value, rtol=1e-6) diff --git a/src/py/flwr/common/dp.py b/src/py/flwr/common/dp.py index 5030ad34805b..83a72b8ce749 100644 --- a/src/py/flwr/common/dp.py +++ b/src/py/flwr/common/dp.py @@ -19,11 +19,13 @@ import numpy as np +from flwr.common.logger import warn_deprecated_feature from flwr.common.typing import NDArrays # Calculates the L2-norm of a potentially ragged array def _get_update_norm(update: NDArrays) -> float: + warn_deprecated_feature("`_get_update_norm` method") flattened_update = update[0] for i in range(1, len(update)): flattened_update = np.append(flattened_update, update[i]) @@ -32,6 +34,7 @@ def _get_update_norm(update: NDArrays) -> float: def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: """Add iid Gaussian noise to each floating point value in the update.""" + 
warn_deprecated_feature("`add_gaussian_noise` method") update_noised = [ layer + np.random.normal(0, std_dev, layer.shape) for layer in update ] @@ -40,6 +43,7 @@ def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: def clip_by_l2(update: NDArrays, threshold: float) -> Tuple[NDArrays, bool]: """Scales the update so thats its L2 norm is upper-bound to threshold.""" + warn_deprecated_feature("`clip_by_l2` method") update_norm = _get_update_norm(update) scaling_factor = min(1, threshold / update_norm) update_clipped: NDArrays = [layer * scaling_factor for layer in update] diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py new file mode 100644 index 000000000000..30750c28a450 --- /dev/null +++ b/src/py/flwr/common/exit_handlers.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Common function to register exit handlers for server and client."""
+
+
+import sys
+from signal import SIGINT, SIGTERM, signal
+from threading import Thread
+from types import FrameType
+from typing import List, Optional
+
+from grpc import Server
+
+from flwr.common.telemetry import EventType, event
+
+
+def register_exit_handlers(
+    event_type: EventType,
+    grpc_servers: Optional[List[Server]] = None,
+    bckg_threads: Optional[List[Thread]] = None,
+) -> None:
+    """Register exit handlers for `SIGINT` and `SIGTERM` signals.
+
+    Parameters
+    ----------
+    event_type : EventType
+        The telemetry event that should be logged before exit.
+    grpc_servers: Optional[List[Server]] (default: None)
+        An optional list of gRPC servers that need to be gracefully
+        terminated before exiting.
+    bckg_threads: Optional[List[Thread]] (default: None)
+        An optional list of threads that need to be gracefully
+        terminated before exiting.
+    """
+    default_handlers = {
+        SIGINT: None,
+        SIGTERM: None,
+    }
+
+    def graceful_exit_handler(  # type: ignore
+        signalnum,
+        frame: FrameType,  # pylint: disable=unused-argument
+    ) -> None:
+        """Exit handler to be registered with `signal.signal`.
+
+        When called will reset signal handler to original signal handler from
+        default_handlers. 
+        """
+        # Reset to default handler
+        signal(signalnum, default_handlers[signalnum])
+
+        event_res = event(event_type=event_type)
+
+        if grpc_servers is not None:
+            for grpc_server in grpc_servers:
+                grpc_server.stop(grace=1)
+
+        if bckg_threads is not None:
+            for bckg_thread in bckg_threads:
+                bckg_thread.join()
+
+        # Ensure event has happened
+        event_res.result()
+
+        # Setup things for graceful exit
+        sys.exit(0)
+
+    default_handlers[SIGINT] = signal(  # type: ignore
+        SIGINT,
+        graceful_exit_handler,  # type: ignore
+    )
+    default_handlers[SIGTERM] = signal(  # type: ignore
+        SIGTERM,
+        graceful_exit_handler,  # type: ignore
+    )
diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py
index 9d0543ea8c75..7d0eba078ab0 100644
--- a/src/py/flwr/common/grpc.py
+++ b/src/py/flwr/common/grpc.py
@@ -15,7 +15,7 @@
 """Utility functions for gRPC."""
 
-from logging import INFO
+from logging import DEBUG
 from typing import Optional
 
 import grpc
@@ -49,12 +49,12 @@ def create_channel(
     if insecure:
         channel = grpc.insecure_channel(server_address, options=channel_options)
-        log(INFO, "Opened insecure gRPC connection (no certificates were passed)")
+        log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
     else:
         ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
         channel = grpc.secure_channel(
             server_address, ssl_channel_credentials, options=channel_options
         )
-        log(INFO, "Opened secure gRPC connection using certificates")
+        log(DEBUG, "Opened secure gRPC connection using certificates")
 
     return channel
diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py
index 29d1562a86d3..2bc41773ed61 100644
--- a/src/py/flwr/common/logger.py
+++ b/src/py/flwr/common/logger.py
@@ -18,21 +18,86 @@
 import logging
 from logging import WARN, LogRecord
 from logging.handlers import HTTPHandler
-from typing import Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple
 
 # Create logger
 LOGGER_NAME = 
"flwr" FLOWER_LOGGER = logging.getLogger(LOGGER_NAME) FLOWER_LOGGER.setLevel(logging.DEBUG) -DEFAULT_FORMATTER = logging.Formatter( - "%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s" -) +LOG_COLORS = { + "DEBUG": "\033[94m", # Blue + "INFO": "\033[92m", # Green + "WARNING": "\033[93m", # Yellow + "ERROR": "\033[91m", # Red + "CRITICAL": "\033[95m", # Magenta + "RESET": "\033[0m", # Reset to default +} + +if TYPE_CHECKING: + StreamHandler = logging.StreamHandler[Any] +else: + StreamHandler = logging.StreamHandler + + +class ConsoleHandler(StreamHandler): + """Console handler that allows configurable formatting.""" + + def __init__( + self, + timestamps: bool = False, + json: bool = False, + colored: bool = True, + stream: Optional[TextIO] = None, + ) -> None: + super().__init__(stream) + self.timestamps = timestamps + self.json = json + self.colored = colored + + def emit(self, record: LogRecord) -> None: + """Emit a record.""" + if self.json: + record.message = record.getMessage().replace("\t", "").strip() + + # Check if the message is empty + if not record.message: + return + + super().emit(record) + + def format(self, record: LogRecord) -> str: + """Format function that adds colors to log level.""" + seperator = " " * (8 - len(record.levelname)) + if self.json: + log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}" + else: + log_fmt = ( + f"{LOG_COLORS[record.levelname] if self.colored else ''}" + f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}" + f"{LOG_COLORS['RESET'] if self.colored else ''}" + f": {seperator} %(message)s" + ) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +def update_console_handler(level: int, timestamps: bool, colored: bool) -> None: + """Update the logging handler.""" + for handler in logging.getLogger(LOGGER_NAME).handlers: + if isinstance(handler, ConsoleHandler): + handler.setLevel(level) + handler.timestamps = timestamps + handler.colored = 
colored + # Configure console logger -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.DEBUG) -console_handler.setFormatter(DEFAULT_FORMATTER) +console_handler = ConsoleHandler( + timestamps=False, + json=False, + colored=True, +) +console_handler.setLevel(logging.INFO) FLOWER_LOGGER.addHandler(console_handler) @@ -103,11 +168,23 @@ def warn_experimental_feature(name: str) -> None: """Warn the user when they use an experimental feature.""" log( WARN, - """ - EXPERIMENTAL FEATURE: %s + """EXPERIMENTAL FEATURE: %s + + This is an experimental feature. It could change significantly or be removed + entirely in future versions of Flower. + """, + name, + ) + + +def warn_deprecated_feature(name: str) -> None: + """Warn the user when they use a deprecated feature.""" + log( + WARN, + """DEPRECATED FEATURE: %s - This is an experimental feature. It could change significantly or be removed - entirely in future versions of Flower. + This is a deprecated feature. It will be removed + entirely in future versions of Flower. """, name, ) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py new file mode 100644 index 000000000000..88cf750f1a94 --- /dev/null +++ b/src/py/flwr/common/message.py @@ -0,0 +1,323 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Message.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .record import RecordSet + + +@dataclass +class Metadata: # pylint: disable=too-many-instance-attributes + """A dataclass holding metadata associated with the current message. + + Parameters + ---------- + run_id : int + An identifier for the current run. + message_id : str + An identifier for the current message. + src_node_id : int + An identifier for the node sending this message. + dst_node_id : int + An identifier for the node receiving this message. + reply_to_message : str + An identifier for the message this message replies to. + group_id : str + An identifier for grouping messages. In some settings, + this is used as the FL round. + ttl : str + Time-to-live for this message. + message_type : str + A string that encodes the action to be executed on + the receiving end. + partition_id : Optional[int] + An identifier that can be used when loading a particular + data partition for a ClientApp. Making use of this identifier + is more relevant when conducting simulations. 
+ """ + + _run_id: int + _message_id: str + _src_node_id: int + _dst_node_id: int + _reply_to_message: str + _group_id: str + _ttl: str + _message_type: str + _partition_id: int | None + + def __init__( # pylint: disable=too-many-arguments + self, + run_id: int, + message_id: str, + src_node_id: int, + dst_node_id: int, + reply_to_message: str, + group_id: str, + ttl: str, + message_type: str, + partition_id: int | None = None, + ) -> None: + self._run_id = run_id + self._message_id = message_id + self._src_node_id = src_node_id + self._dst_node_id = dst_node_id + self._reply_to_message = reply_to_message + self._group_id = group_id + self._ttl = ttl + self._message_type = message_type + self._partition_id = partition_id + + @property + def run_id(self) -> int: + """An identifier for the current run.""" + return self._run_id + + @property + def message_id(self) -> str: + """An identifier for the current message.""" + return self._message_id + + @property + def src_node_id(self) -> int: + """An identifier for the node sending this message.""" + return self._src_node_id + + @property + def reply_to_message(self) -> str: + """An identifier for the message this message replies to.""" + return self._reply_to_message + + @property + def dst_node_id(self) -> int: + """An identifier for the node receiving this message.""" + return self._dst_node_id + + @dst_node_id.setter + def dst_node_id(self, value: int) -> None: + """Set dst_node_id.""" + self._dst_node_id = value + + @property + def group_id(self) -> str: + """An identifier for grouping messages.""" + return self._group_id + + @group_id.setter + def group_id(self, value: str) -> None: + """Set group_id.""" + self._group_id = value + + @property + def ttl(self) -> str: + """Time-to-live for this message.""" + return self._ttl + + @ttl.setter + def ttl(self, value: str) -> None: + """Set ttl.""" + self._ttl = value + + @property + def message_type(self) -> str: + """A string that encodes the action to be executed on the 
 receiving end.""" + return self._message_type + + @message_type.setter + def message_type(self, value: str) -> None: + """Set message_type.""" + self._message_type = value + + @property + def partition_id(self) -> int | None: + """An identifier telling which data partition a ClientApp should use.""" + return self._partition_id + + @partition_id.setter + def partition_id(self, value: int) -> None: + """Set partition_id.""" + self._partition_id = value + + +@dataclass +class Error: + """A dataclass that stores information about an error that occurred. + + Parameters + ---------- + code : int + An identifier for the error. + reason : Optional[str] + A reason for why the error arose (e.g. an exception stack-trace) + """ + + _code: int + _reason: str | None = None + + def __init__(self, code: int, reason: str | None = None) -> None: + self._code = code + self._reason = reason + + @property + def code(self) -> int: + """Error code.""" + return self._code + + @property + def reason(self) -> str | None: + """Reason reported about the error.""" + return self._reason + + +@dataclass +class Message: + """State of your application from the viewpoint of the entity using it. + + Parameters + ---------- + metadata : Metadata + A dataclass including information about the message to be executed. + content : Optional[RecordSet] + Holds records either sent by another entity (e.g. sent by the server-side + logic to a client, or vice-versa) or that will be sent to it. + error : Optional[Error] + A dataclass that captures information about an error that took place + when processing another message. 
+ """ + + _metadata: Metadata + _content: RecordSet | None = None + _error: Error | None = None + + def __init__( + self, + metadata: Metadata, + content: RecordSet | None = None, + error: Error | None = None, + ) -> None: + self._metadata = metadata + + if not (content is None) ^ (error is None): + raise ValueError("Either `content` or `error` must be set, but not both.") + + self._content = content + self._error = error + + @property + def metadata(self) -> Metadata: + """A dataclass including information about the message to be executed.""" + return self._metadata + + @property + def content(self) -> RecordSet: + """The content of this message.""" + if self._content is None: + raise ValueError( + "Message content is None. Use .has_content() " + "to check if a message has content." + ) + return self._content + + @content.setter + def content(self, value: RecordSet) -> None: + """Set content.""" + if self._error is None: + self._content = value + else: + raise ValueError("A message with an error set cannot have content.") + + @property + def error(self) -> Error: + """Error captured by this message.""" + if self._error is None: + raise ValueError( + "Message error is None. Use .has_error() " + "to check first if a message carries an error." 
+ ) + return self._error + + @error.setter + def error(self, value: Error) -> None: + """Set error.""" + if self.has_content(): + raise ValueError("A message with content set cannot carry an error.") + self._error = value + + def has_content(self) -> bool: + """Return True if message has content, else False.""" + return self._content is not None + + def has_error(self) -> bool: + """Return True if message has an error, else False.""" + return self._error is not None + + def _create_reply_metadata(self, ttl: str) -> Metadata: + """Construct metadata for a reply message.""" + return Metadata( + run_id=self.metadata.run_id, + message_id="", + src_node_id=self.metadata.dst_node_id, + dst_node_id=self.metadata.src_node_id, + reply_to_message=self.metadata.message_id, + group_id=self.metadata.group_id, + ttl=ttl, + message_type=self.metadata.message_type, + partition_id=self.metadata.partition_id, + ) + + def create_error_reply( + self, + error: Error, + ttl: str, + ) -> Message: + """Construct a reply message indicating an error happened. + + Parameters + ---------- + error : Error + The error that was encountered. + ttl : str + Time-to-live for this message. + """ + # Create reply with error + message = Message(metadata=self._create_reply_metadata(ttl), error=error) + return message + + def create_reply(self, content: RecordSet, ttl: str) -> Message: + """Create a reply to this message with specified content and TTL. + + The method generates a new `Message` as a reply to this message. + It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from + this message and sets 'reply_to_message' to the ID of this message. + + Parameters + ---------- + content : RecordSet + The content for the reply message. + ttl : str + Time-to-live for this message. + + Returns + ------- + Message + A new `Message` instance representing the reply. 
+ """ + return Message( + metadata=self._create_reply_metadata(ttl), + content=content, + ) diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py new file mode 100644 index 000000000000..ba628bb3235a --- /dev/null +++ b/src/py/flwr/common/message_test.py @@ -0,0 +1,109 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Message tests.""" + + +from contextlib import ExitStack +from typing import Any, Callable + +import pytest + +# pylint: enable=E0611 +from . 
import RecordSet +from .message import Error, Message +from .serde_test import RecordMaker + + +@pytest.mark.parametrize( + "content_fn, error_fn, context", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + None, + ), # check when only content is set + (None, lambda code: Error(code=code), None), # check when only error is set + ( + lambda maker: maker.recordset(1, 1, 1), + lambda code: Error(code=code), + pytest.raises(ValueError), + ), # check when both are set (ERROR) + (None, None, pytest.raises(ValueError)), # check when neither is set (ERROR) + ], +) +def test_message_creation( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], + context: Any, +) -> None: + """Test Message creation attempting to pass content and/or error.""" + # Prepare + maker = RecordMaker(state=2) + metadata = maker.metadata() + + with ExitStack() as stack: + if context: + stack.enter_context(context) + + _ = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + +def create_message_with_content() -> Message: + """Create a Message with content.""" + maker = RecordMaker(state=2) + metadata = maker.metadata() + return Message(metadata=metadata, content=RecordSet()) + + +def create_message_with_error() -> Message: + """Create a Message with error.""" + maker = RecordMaker(state=2) + metadata = maker.metadata() + return Message(metadata=metadata, error=Error(code=1)) + + +@pytest.mark.parametrize( + "message_creation_fn", + [ + create_message_with_content, + create_message_with_error, + ], +) +def test_altering_message( + message_creation_fn: Callable[ + [], + Message, + ], +) -> None: + """Test that a message with content doesn't allow setting an error. + + And viceversa. 
+ """ + message = message_creation_fn() + + with pytest.raises(ValueError): + if message.has_content(): + message.error = Error(code=123) + if message.has_error(): + message.content = RecordSet() diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py new file mode 100644 index 000000000000..4660f07e24a4 --- /dev/null +++ b/src/py/flwr/common/object_ref.py @@ -0,0 +1,140 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper functions to load objects from a reference.""" + + +import ast +import importlib +from importlib.util import find_spec +from typing import Any, Optional, Tuple, Type + +OBJECT_REF_HELP_STR = """ +\n\nThe object reference string should have the form :. Valid +examples include `client:app` and `project.package.module:wrapper.app`. It must +refer to a module on the PYTHONPATH and the module needs to have the specified +attribute. +""" + + +def validate( + module_attribute_str: str, +) -> Tuple[bool, Optional[str]]: + """Validate object reference. + + The object reference string should have the form :. Valid + examples include `client:app` and `project.package.module:wrapper.app`. It must + refer to a module on the PYTHONPATH and the module needs to have the specified + attribute. 
+ + Returns + ------- + Tuple[bool, Optional[str]] + A boolean indicating whether an object reference is valid and + the reason why it might not be. + """ + module_str, _, attributes_str = module_attribute_str.partition(":") + if not module_str: + return ( + False, + f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}", + ) + if not attributes_str: + return ( + False, + f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}", + ) + + # Load module + module = find_spec(module_str) + if module and module.origin: + if not _find_attribute_in_module(module.origin, attributes_str): + return ( + False, + f"Unable to find attribute {attributes_str} in module {module_str}" + f"{OBJECT_REF_HELP_STR}", + ) + return (True, None) + + return ( + False, + f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", + ) + + +def load_app( + module_attribute_str: str, + error_type: Type[Exception], +) -> Any: + """Return the object specified in a module attribute string. + + The module/attribute string should have the form :. Valid + examples include `client:app` and `project.package.module:wrapper.app`. It must + refer to a module on the PYTHONPATH, the module needs to have the specified + attribute. 
+ """ + valid, error_msg = validate(module_attribute_str) + if not valid and error_msg: + raise error_type(error_msg) from None + + module_str, _, attributes_str = module_attribute_str.partition(":") + + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError: + raise error_type( + f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", + ) from None + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError: + raise error_type( + f"Unable to load attribute {attributes_str} from module {module_str}" + f"{OBJECT_REF_HELP_STR}", + ) from None + + return attribute + + +def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool: + """Check if attribute_name exists in module's abstract symbolic tree.""" + with open(file_path, encoding="utf-8") as file: + node = ast.parse(file.read(), filename=file_path) + + for n in ast.walk(node): + if isinstance(n, ast.Assign): + for target in n.targets: + if isinstance(target, ast.Name) and target.id == attribute_name: + return True + if _is_module_in_all(attribute_name, target, n): + return True + return False + + +def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool: + """Now check if attribute_name is in __all__.""" + if isinstance(target, ast.Name) and target.id == "__all__": + if isinstance(n.value, ast.List): + for elt in n.value.elts: + if isinstance(elt, ast.Str) and elt.s == attribute_name: + return True + elif isinstance(n.value, ast.Tuple): + for elt in n.value.elts: + if isinstance(elt, ast.Str) and elt.s == attribute_name: + return True + return False diff --git a/src/py/flwr/common/object_ref_test.py b/src/py/flwr/common/object_ref_test.py new file mode 100644 index 000000000000..f4513a4319ab --- /dev/null +++ b/src/py/flwr/common/object_ref_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the validation function of object refs.""" + +from .object_ref import OBJECT_REF_HELP_STR, validate + + +def test_validate_object_reference() -> None: + """Test that validate_object_reference succeeds correctly.""" + # Prepare + ref = "flwr.cli.run:run" + + # Execute + is_valid, error = validate(ref) + + # Assert + assert is_valid + assert error is None + + +def test_validate_object_reference_fails() -> None: + """Test that validate_object_reference fails correctly.""" + # Prepare + ref = "flwr.cli.run:runa" + + # Execute + is_valid, error = validate(ref) + + # Assert + assert not is_valid + assert ( + error + == f"Unable to find attribute runa in module flwr.cli.run{OBJECT_REF_HELP_STR}" + ) diff --git a/src/py/flwr/common/pyproject.py b/src/py/flwr/common/pyproject.py new file mode 100644 index 000000000000..66585e422397 --- /dev/null +++ b/src/py/flwr/common/pyproject.py @@ -0,0 +1,41 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Validates the project's name property.""" + +import re + + +def validate_project_name(name: str) -> bool: + """Validate the project name against PEP 621 and PEP 503 specifications. + + Conventions at a glance: + - Must be lowercase + - Must not contain special characters + - Must use hyphens(recommended) or underscores. No spaces. + - Recommended to be no more than 40 characters long (But it can be) + + Parameters + ---------- + name : str + The project name to validate. + + Returns + ------- + bool + True if the name is valid, False otherwise. + """ + if not name or len(name) > 40 or not re.match(r"^[a-z0-9-_]+$", name): + return False + return True diff --git a/src/py/flwr/common/pyproject_test.py b/src/py/flwr/common/pyproject_test.py new file mode 100644 index 000000000000..88a945054b83 --- /dev/null +++ b/src/py/flwr/common/pyproject_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the function that validates name property.""" + +from .pyproject import validate_project_name + + +# Happy Flow +def test_valid_name_with_lower_case() -> None: + """Test a valid single-word project name with all lower case.""" + # Prepare + name = "myproject" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_valid_name_with_dashes() -> None: + """Test a valid project name with hyphens inbetween.""" + # Prepare + name = "valid-project-name" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_valid_name_with_underscores() -> None: + """Test a valid project name with underscores inbetween.""" + # Prepare + name = "valid_project_name" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_invalid_name_with_upper_letters() -> None: + """Tests a project name with Spaces and Uppercase letter.""" + # Prepare + name = "Invalid Project Name" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Upper Case and Spaces are not allowed" + + +def test_name_with_spaces() -> None: + """Tests a project name with spaces inbetween.""" + # Prepare + name = "name with spaces" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Spaces are not allowed" + + +def test_empty_name() -> None: + """Tests use-case for an empty project name.""" + # Prepare + name = "" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Empty name is not valid" + + +def test_long_name() -> None: + """Tests for long project names.""" + # Prepare + 
name = "a" * 41 + expected = False + # Execute + actual = validate_project_name(name) + # Assert + # It can be more than 40 but generally + # it is recommended not to be more than 40 + assert actual == expected, "Name longer than 40 characters is not recommended" + + +def test_name_with_special_characters() -> None: + """Tests for project names with special characters.""" + # Prepare + name = "name!@#" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Special characters are not allowed" diff --git a/src/py/flwr/common/record/__init__.py b/src/py/flwr/common/record/__init__.py new file mode 100644 index 000000000000..60bc54b8552a --- /dev/null +++ b/src/py/flwr/common/record/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Record APIs.""" + +from .configsrecord import ConfigsRecord +from .conversion_utils import array_from_numpy +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet + +__all__ = [ + "Array", + "array_from_numpy", + "ConfigsRecord", + "MetricsRecord", + "ParametersRecord", + "RecordSet", +] diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py new file mode 100644 index 000000000000..471c85f0b961 --- /dev/null +++ b/src/py/flwr/common/record/configsrecord.py @@ -0,0 +1,123 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""ConfigsRecord.""" + + +from typing import Dict, List, Optional, get_args + +from flwr.common.typing import ConfigsRecordValues, ConfigsScalar + +from .typeddict import TypedDict + + +def _check_key(key: str) -> None: + """Check if key is of expected type.""" + if not isinstance(key, str): + raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.") + + +def _check_value(value: ConfigsRecordValues) -> None: + def is_valid(__v: ConfigsScalar) -> None: + """Check if value is of expected type.""" + if not isinstance(__v, get_args(ConfigsScalar)): + raise TypeError( + "Not all values are of valid type." + f" Expected `{ConfigsRecordValues}` but `{type(__v)}` was passed." + ) + + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such config as + # an array and pass it to a ParametersRecord. + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {ConfigsScalar}." + ) + else: + is_valid(value) + + +class ConfigsRecord(TypedDict[str, ConfigsRecordValues]): + """Configs record.""" + + def __init__( + self, + configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None, + keep_input: bool = True, + ) -> None: + """Construct a ConfigsRecord object. + + Parameters + ---------- + configs_dict : Optional[Dict[str, ConfigsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as + defined in `ConfigsScalar`) and lists of such types (see + `ConfigsScalarList`). 
+ keep_input : bool (default: True) + A boolean indicating whether config passed should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. + """ + super().__init__(_check_key, _check_value) + if configs_dict: + for k in list(configs_dict.keys()): + self[k] = configs_dict[k] + if not keep_input: + del configs_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object. + + This function counts booleans as occupying 1 Byte. + """ + + def get_var_bytes(value: ConfigsScalar) -> int: + """Return Bytes of value passed.""" + if isinstance(value, bool): + var_bytes = 1 + elif isinstance(value, (int, float)): + var_bytes = ( + 8 # the protobuf encoding represents int/floats in ConfigRecords as 64bit + ) + if isinstance(value, (str, bytes)): + var_bytes = len(value) + return var_bytes + + num_bytes = 0 + + for k, v in self.items(): + if isinstance(v, List): + if isinstance(v[0], (bytes, str)): + # not all str are of equal length necessarily + # for both the footprint of each element is 1 Byte + num_bytes += int(sum(len(s) for s in v)) # type: ignore + else: + num_bytes += get_var_bytes(v[0]) * len(v) + else: + num_bytes += get_var_bytes(v) + + # We also count the bytes footprint of the keys + num_bytes += len(k) + + return num_bytes diff --git a/src/py/flwr/common/record/conversion_utils.py b/src/py/flwr/common/record/conversion_utils.py new file mode 100644 index 000000000000..7cc0b04283e9 --- /dev/null +++ b/src/py/flwr/common/record/conversion_utils.py @@ -0,0 +1,40 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion utility functions for Records.""" + + +from io import BytesIO + +import numpy as np + +from ..constant import SType +from ..typing import NDArray +from .parametersrecord import Array + + +def array_from_numpy(ndarray: NDArray) -> Array: + """Create Array from NumPy ndarray.""" + buffer = BytesIO() + # WARNING: NEVER set allow_pickle to true. + # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html + np.save(buffer, ndarray, allow_pickle=False) + data = buffer.getvalue() + return Array( + dtype=str(ndarray.dtype), + shape=list(ndarray.shape), + stype=SType.NUMPY, + data=data, + ) diff --git a/src/py/flwr/common/record/conversion_utils_test.py b/src/py/flwr/common/record/conversion_utils_test.py new file mode 100644 index 000000000000..84be37fda4a3 --- /dev/null +++ b/src/py/flwr/common/record/conversion_utils_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for conversion utility functions.""" + + +import unittest +from io import BytesIO + +import numpy as np + +from ..constant import SType +from .conversion_utils import array_from_numpy + + +class TestArrayFromNumpy(unittest.TestCase): + """Unit tests for array_from_numpy.""" + + def test_array_from_numpy(self) -> None: + """Test the array_from_numpy function.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + + # Execute + array_instance = array_from_numpy(original_array) + buffer = BytesIO(array_instance.data) + deserialized_array = np.load(buffer, allow_pickle=False) + + # Assert + self.assertEqual(array_instance.dtype, str(original_array.dtype)) + self.assertEqual(array_instance.shape, list(original_array.shape)) + self.assertEqual(array_instance.stype, SType.NUMPY) + np.testing.assert_array_equal(deserialized_array, original_array) diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py new file mode 100644 index 000000000000..2b6e584be390 --- /dev/null +++ b/src/py/flwr/common/record/metricsrecord.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""MetricsRecord.""" + + +from typing import Dict, List, Optional, get_args + +from flwr.common.typing import MetricsRecordValues, MetricsScalar + +from .typeddict import TypedDict + + +def _check_key(key: str) -> None: + """Check if key is of expected type.""" + if not isinstance(key, str): + raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.") + + +def _check_value(value: MetricsRecordValues) -> None: + def is_valid(__v: MetricsScalar) -> None: + """Check if value is of expected type.""" + if not isinstance(__v, get_args(MetricsScalar)) or isinstance(__v, bool): + raise TypeError( + "Not all values are of valid type." + f" Expected `{MetricsRecordValues}` but `{type(__v)}` was passed." + ) + + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such metric as + # an array and pass it to a ParametersRecord. + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {MetricsScalar}." + ) + else: + is_valid(value) + + +class MetricsRecord(TypedDict[str, MetricsRecordValues]): + """Metrics record.""" + + def __init__( + self, + metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None, + keep_input: bool = True, + ): + """Construct a MetricsRecord object. + + Parameters + ---------- + metrics_dict : Optional[Dict[str, MetricsRecordValues]] + A dictionary that stores basic types (i.e. `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `MetricsScalarList`). 
+ keep_input : bool (default: True) + A boolean indicating whether metrics should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. + """ + super().__init__(_check_key, _check_value) + if metrics_dict: + for k in list(metrics_dict.keys()): + self[k] = metrics_dict[k] + if not keep_input: + del metrics_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object.""" + num_bytes = 0 + + for k, v in self.items(): + if isinstance(v, List): + # both int and float normally take 4 bytes + # But MetricRecords are mapped to 64bit int/float + # during protobuffing + num_bytes += 8 * len(v) + else: + num_bytes += 8 + # We also count the bytes footprint of the keys + num_bytes += len(k) + return num_bytes diff --git a/src/py/flwr/common/record/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py new file mode 100644 index 000000000000..a4a71f751f97 --- /dev/null +++ b/src/py/flwr/common/record/parametersrecord.py @@ -0,0 +1,136 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""ParametersRecord and Array.""" + +from dataclasses import dataclass +from io import BytesIO +from typing import List, Optional, OrderedDict, cast + +import numpy as np + +from ..constant import SType +from ..typing import NDArray +from .typeddict import TypedDict + + +@dataclass +class Array: + """Array type. + + A dataclass containing serialized data from an array-like or tensor-like object + along with some metadata about it. + + Parameters + ---------- + dtype : str + A string representing the data type of the serialised object (e.g. `np.float32`) + + shape : List[int] + A list representing the shape of the unserialized array-like object. This is + used to deserialize the data (depending on the serialization method) or simply + as a metadata field. + + stype : str + A string indicating the type of serialisation mechanism used to generate the + bytes in `data` from an array-like or tensor-like object. + + data: bytes + A buffer of bytes containing the data. + """ + + dtype: str + shape: List[int] + stype: str + data: bytes + + def numpy(self) -> NDArray: + """Return the array as a NumPy array.""" + if self.stype != SType.NUMPY: + raise TypeError( + f"Unsupported serialization type for numpy conversion: '{self.stype}'" + ) + bytes_io = BytesIO(self.data) + # WARNING: NEVER set allow_pickle to true. 
+ # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html + ndarray_deserialized = np.load(bytes_io, allow_pickle=False) + return cast(NDArray, ndarray_deserialized) + + +def _check_key(key: str) -> None: + """Check if key is of expected type.""" + if not isinstance(key, str): + raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.") + + +def _check_value(value: Array) -> None: + if not isinstance(value, Array): + raise TypeError( + f"Value must be of type `{Array}` but `{type(value)}` was passed." + ) + + +@dataclass +class ParametersRecord(TypedDict[str, Array]): + """Parameters record. + + A dataclass storing named Arrays in order. This means that it holds entries as an + OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to + PyTorch's state_dict, but holding serialised tensors instead. + """ + + def __init__( + self, + array_dict: Optional[OrderedDict[str, Array]] = None, + keep_input: bool = False, + ) -> None: + """Construct a ParametersRecord object. + + Parameters + ---------- + array_dict : Optional[OrderedDict[str, Array]] + A dictionary that stores serialized array-like or tensor-like objects. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + dictionary immediately after adding them to the record. If False, the + dictionary passed to `set_parameters()` will be empty once exiting from that + function. This is the desired behaviour when working with very large + models/tensors/arrays. However, if you plan to continue working with your + parameters after adding it to the record, set this flag to True. When set + to True, the data is duplicated in memory. 
+ """ + super().__init__(_check_key, _check_value) + if array_dict: + for k in list(array_dict.keys()): + self[k] = array_dict[k] + if not keep_input: + del array_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object. + + Note that a small amount of Bytes might also be included in this counting that + correspond to metadata of the serialized object (e.g. of NumPy array) needed for + deseralization. + """ + num_bytes = 0 + + for k, v in self.items(): + num_bytes += len(v.data) + + # We also count the bytes footprint of the keys + num_bytes += len(k) + + return num_bytes diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py new file mode 100644 index 000000000000..e840e5e266e4 --- /dev/null +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unit tests for ParametersRecord and Array.""" + +import unittest +from collections import OrderedDict +from io import BytesIO +from typing import List + +import numpy as np +import pytest + +from flwr.common import ndarray_to_bytes + +from ..constant import SType +from ..typing import NDArray +from .parametersrecord import Array, ParametersRecord + + +def _get_buffer_from_ndarray(array: NDArray) -> bytes: + """Return a bytes buffer froma given NumPy array.""" + buffer = BytesIO() + np.save(buffer, array, allow_pickle=False) + return buffer.getvalue() + + +class TestArray(unittest.TestCase): + """Unit tests for Array.""" + + def test_numpy_conversion_valid(self) -> None: + """Test the numpy method with valid Array instance.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + + buffer = _get_buffer_from_ndarray(original_array) + + # Execute + array_instance = Array( + dtype=str(original_array.dtype), + shape=list(original_array.shape), + stype=SType.NUMPY, + data=buffer, + ) + converted_array = array_instance.numpy() + + # Assert + np.testing.assert_array_equal(converted_array, original_array) + + def test_numpy_conversion_invalid(self) -> None: + """Test the numpy method with invalid Array instance.""" + # Prepare + array_instance = Array( + dtype="float32", + shape=[3], + stype="invalid_stype", # Non-numpy stype + data=b"", + ) + + # Execute and assert + with self.assertRaises(TypeError): + array_instance.numpy() + + +@pytest.mark.parametrize( + "shape, dtype", + [ + ([100], "float32"), + ([31, 31], "int8"), + ([31, 153], "bool_"), # bool_ is represented as a whole Byte in NumPy + ], +) +def test_count_bytes(shape: List[int], dtype: str) -> None: + """Test bytes in a ParametersRecord are computed correctly.""" + original_array = np.random.randn(*shape).astype(np.dtype(dtype)) + + buff = ndarray_to_bytes(original_array) + + buffer = 
_get_buffer_from_ndarray(original_array) + + array_instance = Array( + dtype=str(original_array.dtype), + shape=list(original_array.shape), + stype=SType.NUMPY, + data=buffer, + ) + key_name = "data" + p_record = ParametersRecord(OrderedDict({key_name: array_instance})) + + assert len(buff) + len(key_name) == p_record.count_bytes() diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py new file mode 100644 index 000000000000..d8ef44ab15c2 --- /dev/null +++ b/src/py/flwr/common/record/recordset.py @@ -0,0 +1,79 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""RecordSet.""" + + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Type, TypeVar + +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import ParametersRecord +from .typeddict import TypedDict + +T = TypeVar("T") + + +@dataclass +class RecordSet: + """RecordSet stores groups of parameters, metrics and configs.""" + + _parameters_records: TypedDict[str, ParametersRecord] + _metrics_records: TypedDict[str, MetricsRecord] + _configs_records: TypedDict[str, ConfigsRecord] + + def __init__( + self, + parameters_records: Optional[Dict[str, ParametersRecord]] = None, + metrics_records: Optional[Dict[str, MetricsRecord]] = None, + configs_records: Optional[Dict[str, ConfigsRecord]] = None, + ) -> None: + def _get_check_fn(__t: Type[T]) -> Callable[[T], None]: + def _check_fn(__v: T) -> None: + if not isinstance(__v, __t): + raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.") + + return _check_fn + + self._parameters_records = TypedDict[str, ParametersRecord]( + _get_check_fn(str), _get_check_fn(ParametersRecord) + ) + self._metrics_records = TypedDict[str, MetricsRecord]( + _get_check_fn(str), _get_check_fn(MetricsRecord) + ) + self._configs_records = TypedDict[str, ConfigsRecord]( + _get_check_fn(str), _get_check_fn(ConfigsRecord) + ) + if parameters_records is not None: + self._parameters_records.update(parameters_records) + if metrics_records is not None: + self._metrics_records.update(metrics_records) + if configs_records is not None: + self._configs_records.update(configs_records) + + @property + def parameters_records(self) -> TypedDict[str, ParametersRecord]: + """Dictionary holding ParametersRecord instances.""" + return self._parameters_records + + @property + def metrics_records(self) -> TypedDict[str, MetricsRecord]: + """Dictionary holding MetricsRecord instances.""" + 
return self._metrics_records + + @property + def configs_records(self) -> TypedDict[str, ConfigsRecord]: + """Dictionary holding ConfigsRecord instances.""" + return self._configs_records diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py new file mode 100644 index 000000000000..0e0b149881be --- /dev/null +++ b/src/py/flwr/common/record/recordset_test.py @@ -0,0 +1,400 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet tests.""" + +from copy import deepcopy +from typing import Callable, Dict, List, OrderedDict, Type, Union + +import numpy as np +import pytest + +from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.recordset_compat import ( + parameters_to_parametersrecord, + parametersrecord_to_parameters, +) +from flwr.common.typing import ( + ConfigsRecordValues, + MetricsRecordValues, + NDArray, + NDArrays, + Parameters, +) + +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +def ndarray_to_array(ndarray: NDArray) -> Array: + """Represent NumPy ndarray as Array.""" + return Array( + data=ndarray.tobytes(), + dtype=str(ndarray.dtype), + stype="numpy.ndarray.tobytes", + shape=list(ndarray.shape), + ) + + +def test_ndarray_to_array() -> None: + """Test creation of Array object from NumPy ndarray.""" + shape = (2, 7, 9) + arr = np.eye(*shape) + + array = ndarray_to_array(arr) + + arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) + + assert np.array_equal(arr, arr_) + + +def test_parameters_to_array_and_back() -> None: + """Test conversion between legacy Parameters and Array.""" + ndarrays = get_ndarrays() + + # Array represents a single array, unlike parameters, which represent a + # list of arrays + ndarray = ndarrays[0] + + parameters = ndarrays_to_parameters([ndarray]) + + array = Array( + data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[] + ) + + parameters = Parameters(tensors=[array.data], tensor_type=array.stype) + + ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] + + assert np.array_equal(ndarray, ndarray_) + + +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + (False, lambda x, x_copy, y: len(x.tensors) == 0), # check tensors were freed + (True, lambda x, x_copy, y: x.tensors == y.tensors), # check they are equal + ], +) +def test_parameters_to_parametersrecord_and_back( + keep_input: bool, + validate_freed_fn: Callable[[Parameters, Parameters, Parameters], bool], +) -> None: + """Test conversion between legacy Parameters and ParametersRecords.""" + ndarrays = get_ndarrays() + + parameters = ndarrays_to_parameters(ndarrays) + parameters_copy = deepcopy(parameters) + + params_record = 
parameters_to_parametersrecord( + parameters=parameters, keep_input=keep_input + ) + + parameters_ = parametersrecord_to_parameters(params_record, keep_input=keep_input) + + ndarrays_ = parameters_to_ndarrays(parameters=parameters_) + + # Validate returned NDArrays match those at the beginning + for arr, arr_ in zip(ndarrays, ndarrays_): + assert np.array_equal(arr, arr_), "no" + + # Validate initial Parameters object has been handled according to `keep_input` + assert validate_freed_fn(parameters, parameters_copy, parameters_) + + +def test_set_parameters_while_keeping_intputs() -> None: + """Tests keep_input functionality in ParametersRecord.""" + # Adding parameters to a record that doesn't erase entries in the input `array_dict` + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record = ParametersRecord(array_dict, keep_input=True) + + # Creating a second parametersrecord passing the same `array_dict` (not erased) + p_record_2 = ParametersRecord(array_dict) + assert p_record == p_record_2 + + # Now it should be empty (the second ParametersRecord wasn't flagged to keep it) + assert len(array_dict) == 0 + + +def test_set_parameters_with_correct_types() -> None: + """Test adding dictionary of Arrays to ParametersRecord.""" + p_record = ParametersRecord() + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.update(array_dict) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # correct key, incorrect value + (str, lambda x: x.tolist()), # correct key, incorrect value + (int, ndarray_to_array), # incorrect key, correct value + (int, lambda x: x), # incorrect key, incorrect value + (int, lambda x: x.tolist()), # incorrect key, incorrect value + ], +) +def test_set_parameters_with_incorrect_types( + key_type: Type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, List[float]]], +) -> 
None: + """Test adding dictionary of unsupported types to ParametersRecord.""" + p_record = ParametersRecord() + + array_dict = { + key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays()) + } + + with pytest.raises(TypeError): + p_record.update(array_dict) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: float(x.flatten()[0])), # str: float + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + (str, lambda x: []), # str: empty list + ], +) +def test_set_metrics_to_metricsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[[NDArray], MetricsRecordValues], +) -> None: + """Test adding metrics of various types to a MetricsRecord.""" + m_record = MetricsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + # Add metric + m_record.update(my_metrics) + + # Check metrics are actually added + assert my_metrics == m_record + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str (supported: unsupported) + (str, lambda x: bool(x.flatten()[0])), # str: bool (supported: unsupported) + ( + str, + lambda x: x.flatten().astype("str").tolist(), + ), # str: List[str] (supported: unsupported) + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: 
x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_metrics_to_metricsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float, bool]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding metrics of various unsupported types to a MetricsRecord.""" + m_record = MetricsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + with pytest.raises(TypeError): + m_record.update(my_metrics) + + +@pytest.mark.parametrize( + "keep_input", + [ + (True), + (False), + ], +) +def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( + keep_input: bool, +) -> None: + """Test keep_input functionality for MetricsRecord.""" + # constructing a valid input + labels = [1, 2.0] + arrays = get_ndarrays() + my_metrics = OrderedDict( + {str(label): arr.flatten().tolist() for label, arr in zip(labels, arrays)} + ) + + my_metrics_copy = my_metrics.copy() + + # Add metric + m_record = MetricsRecord(my_metrics, keep_input=keep_input) + + # Check metrics are actually added + # Check that input dict has been emptied when enabled such behaviour + if keep_input: + assert my_metrics == m_record + else: + assert my_metrics_copy == m_record + assert len(my_metrics) == 0 + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: float(x.flatten()[0])), # str: float + (str, lambda x: bool(x.flatten()[0])), # str: bool + (str, lambda x: x.flatten().tobytes()), # str: bytes + (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + (str, lambda x: x.flatten().astype("bool").tolist()), # str: 
List[bool] + (str, lambda x: [x.flatten().tobytes()]), # str: List[bytes] + (str, lambda x: []), # str: emptyt list + ], +) +def test_set_configs_to_configsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[[NDArray], ConfigsRecordValues], +) -> None: + """Test adding configs of various types to a ConfigsRecord.""" + labels = [1, 2.0] + arrays = get_ndarrays() + + my_configs = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + c_record = ConfigsRecord(my_configs) + + # check values are actually there + assert c_record == my_configs + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_configs_to_configsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding configs of various unsupported types to a ConfigsRecord.""" + c_record = ConfigsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_configs = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + with pytest.raises(TypeError): + c_record.update(my_configs) + + +def test_count_bytes_metricsrecord() -> None: + """Test counting bytes in MetricsRecord.""" + data = {"a": 1, "b": 2.0, "c": [1, 2, 3], "d": [1.0, 2.0, 3.0, 4.0, 5.0]} + bytes_in_dict = 8 + 8 + 3 * 8 + 5 * 8 + 
bytes_in_dict += 4 # represnting the keys + + m_record = MetricsRecord() + m_record.update(OrderedDict(data)) + record_bytest_count = m_record.count_bytes() + assert bytes_in_dict == record_bytest_count + + +def test_count_bytes_configsrecord() -> None: + """Test counting bytes in ConfigsRecord.""" + data = {"a": 1, "b": 2.0, "c": [1, 2, 3], "d": [1.0, 2.0, 3.0, 4.0, 5.0]} + bytes_in_dict = 8 + 8 + 3 * 8 + 5 * 8 + bytes_in_dict += 4 # represnting the keys + + to_add = { + "aa": True, + "bb": "False", + "cc": bytes(9), + "dd": [True, False, False], + "ee": ["True", "False"], + "ff": [bytes(1), bytes(13), bytes(51)], + } + data = {**data, **to_add} + bytes_in_dict += 1 + 5 + 9 + 3 + (4 + 5) + (1 + 13 + 51) + bytes_in_dict += 12 # represnting the keys + + bytes_in_dict = int(bytes_in_dict) + + c_record = ConfigsRecord() + c_record.update(OrderedDict(data)) + + record_bytest_count = c_record.count_bytes() + assert bytes_in_dict == record_bytest_count diff --git a/src/py/flwr/common/record/typeddict.py b/src/py/flwr/common/record/typeddict.py new file mode 100644 index 000000000000..23d70dc4f7e8 --- /dev/null +++ b/src/py/flwr/common/record/typeddict.py @@ -0,0 +1,113 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Typed dict base class for *Records.""" + + +from typing import Any, Callable, Dict, Generic, Iterator, Tuple, TypeVar, cast + +K = TypeVar("K") # Key type +V = TypeVar("V") # Value type + + +class TypedDict(Generic[K, V]): + """Typed dictionary.""" + + def __init__( + self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None] + ): + self._data: Dict[K, V] = {} + self._check_key_fn = check_key_fn + self._check_value_fn = check_value_fn + + def __setitem__(self, key: K, value: V) -> None: + """Set the given key to the given value after type checking.""" + # Check the types of key and value + self._check_key_fn(key) + self._check_value_fn(value) + # Set key-value pair + self._data[key] = value + + def __delitem__(self, key: K) -> None: + """Remove the item with the specified key.""" + del self._data[key] + + def __getitem__(self, item: K) -> V: + """Return the value for the specified key.""" + return self._data[item] + + def __iter__(self) -> Iterator[K]: + """Yield an iterator over the keys of the dictionary.""" + return iter(self._data) + + def __repr__(self) -> str: + """Return a string representation of the dictionary.""" + return self._data.__repr__() + + def __len__(self) -> int: + """Return the number of items in the dictionary.""" + return len(self._data) + + def __contains__(self, key: K) -> bool: + """Check if the dictionary contains the specified key.""" + return key in self._data + + def __eq__(self, other: object) -> bool: + """Compare this instance to another dictionary or TypedDict.""" + if isinstance(other, TypedDict): + return self._data == other._data + if isinstance(other, dict): + return self._data == other + return NotImplemented + + def items(self) -> Iterator[Tuple[K, V]]: + """R.items() -> a set-like object providing a view on R's items.""" + return cast(Iterator[Tuple[K, V]], self._data.items()) + + def keys(self) -> Iterator[K]: + 
"""R.keys() -> a set-like object providing a view on R's keys.""" + return cast(Iterator[K], self._data.keys()) + + def values(self) -> Iterator[V]: + """R.values() -> an object providing a view on R's values.""" + return cast(Iterator[V], self._data.values()) + + def update(self, *args: Any, **kwargs: Any) -> None: + """R.update([E, ]**F) -> None. + + Update R from dict/iterable E and F. + """ + for key, value in dict(*args, **kwargs).items(): + self[key] = value + + def pop(self, key: K) -> V: + """R.pop(k[,d]) -> v, remove specified key and return the corresponding value. + + If key is not found, d is returned if given, otherwise KeyError is raised. + """ + return self._data.pop(key) + + def get(self, key: K, default: V) -> V: + """R.get(k[,d]) -> R[k] if k in R, else d. + + d defaults to None. + """ + return self._data.get(key, default) + + def clear(self) -> None: + """R.clear() -> None. + + Remove all items from R. + """ + self._data.clear() diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py new file mode 100644 index 000000000000..394ea1353bab --- /dev/null +++ b/src/py/flwr/common/recordset_compat.py @@ -0,0 +1,399 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args + +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet +from .typing import ( + Code, + ConfigsRecordValues, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesIns, + GetPropertiesRes, + MetricsRecordValues, + Parameters, + Scalar, + Status, +) + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool +) -> Parameters: + """Convert ParameterRecord to legacy Parameters. + + Warnings + -------- + Because `Arrays` in `ParametersRecord` encode more information of the + array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it + might not be possible to reconstruct such data structures from `Parameters` objects + alone. Additional information or metadata must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be converted into Parameters. + keep_input : bool + A boolean indicating whether entries in the record should be deleted from the + input dictionary immediately after adding them to the record. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.keys()): + parameters.tensors.append(record[key].data) + + if not parameters.tensor_type: + # Setting from first array in record. Recall the warning in the docstrings + # of this function. + parameters.tensor_type = record[key].stype + + if not keep_input: + del record[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. 
+ keep_input : bool + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. + """ + tensor_type = parameters.tensor_type + + num_arrays = len(parameters.tensors) + ordered_dict = OrderedDict() + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + ordered_dict[str(idx)] = Array( + data=tensor, dtype="", stype=tensor_type, shape=[] + ) + + return ParametersRecord(ordered_dict, keep_input=keep_input) + + +def _check_mapping_from_recordscalartype_to_scalar( + record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]] +) -> Dict[str, Scalar]: + """Check mapping `common.*RecordValues` into `common.Scalar` is possible.""" + for value in record_data.values(): + if not isinstance(value, get_args(Scalar)): + raise TypeError( + "There is not a 1:1 mapping between `common.Scalar` types and those " + "supported in `common.ConfigsRecordValues` or " + "`common.MetricsRecordValues`. Consider casting your values to a type " + "supported by the `common.RecordSet` infrastructure. 
" + f"You used type: {type(value)}" + ) + return cast(Dict[str, Scalar], record_data) + + +def _recordset_to_fit_or_evaluate_ins_components( + recordset: RecordSet, + ins_str: str, + keep_input: bool, +) -> Tuple[Parameters, Dict[str, Scalar]]: + """Derive Fit/Evaluate Ins from a RecordSet.""" + # get Array and construct Parameters + parameters_record = recordset.parameters_records[f"{ins_str}.parameters"] + + parameters = parametersrecord_to_parameters( + parameters_record, keep_input=keep_input + ) + + # get config dict + config_record = recordset.configs_records[f"{ins_str}.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) + + return parameters, config_dict + + +def _fit_or_evaluate_ins_to_recordset( + ins: Union[FitIns, EvaluateIns], keep_input: bool +) -> RecordSet: + recordset = RecordSet() + + ins_str = "fitins" if isinstance(ins, FitIns) else "evaluateins" + parametersrecord = parameters_to_parametersrecord(ins.parameters, keep_input) + recordset.parameters_records[f"{ins_str}.parameters"] = parametersrecord + + recordset.configs_records[f"{ins_str}.config"] = ConfigsRecord( + ins.config # type: ignore + ) + + return recordset + + +def _embed_status_into_recordset( + res_str: str, status: Status, recordset: RecordSet +) -> RecordSet: + status_dict: Dict[str, ConfigsRecordValues] = { + "code": int(status.code.value), + "message": status.message, + } + # we add it to a `ConfigsRecord`` because the `status.message`` is a string + # and `str` values aren't supported in `MetricsRecords` + recordset.configs_records[f"{res_str}.status"] = ConfigsRecord(status_dict) + return recordset + + +def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status: + status = recordset.configs_records[f"{res_str}.status"] + code = cast(int, status["code"]) + return Status(code=Code(code), message=str(status["message"])) + + +def recordset_to_fitins(recordset: RecordSet, 
keep_input: bool) -> FitIns: + """Derive FitIns from a RecordSet object.""" + parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="fitins", + keep_input=keep_input, + ) + + return FitIns(parameters=parameters, config=config) + + +def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitIns object.""" + return _fit_or_evaluate_ins_to_recordset(fitins, keep_input) + + +def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes: + """Derive FitRes from a RecordSet object.""" + ins_str = "fitres" + parameters = parametersrecord_to_parameters( + recordset.parameters_records[f"{ins_str}.parameters"], keep_input=keep_input + ) + + num_examples = cast( + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] + ) + configs_record = recordset.configs_records[f"{ins_str}.metrics"] + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) + status = _extract_status_from_recordset(ins_str, recordset) + + return FitRes( + status=status, parameters=parameters, num_examples=num_examples, metrics=metrics + ) + + +def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitRes object.""" + recordset = RecordSet() + + res_str = "fitres" + + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + fitres.metrics # type: ignore + ) + recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": fitres.num_examples}, + ) + recordset.parameters_records[f"{res_str}.parameters"] = ( + parameters_to_parametersrecord( + fitres.parameters, + keep_input, + ) + ) + + # status + recordset = _embed_status_into_recordset(res_str, fitres.status, recordset) + + return recordset + + +def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns: + """Derive EvaluateIns from a RecordSet object.""" + 
parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="evaluateins", + keep_input=keep_input, + ) + + return EvaluateIns(parameters=parameters, config=config) + + +def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a EvaluateIns object.""" + return _fit_or_evaluate_ins_to_recordset(evaluateins, keep_input) + + +def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes: + """Derive EvaluateRes from a RecordSet object.""" + ins_str = "evaluateres" + + loss = cast(int, recordset.metrics_records[f"{ins_str}.loss"]["loss"]) + + num_examples = cast( + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] + ) + configs_record = recordset.configs_records[f"{ins_str}.metrics"] + + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) + status = _extract_status_from_recordset(ins_str, recordset) + + return EvaluateRes( + status=status, loss=loss, num_examples=num_examples, metrics=metrics + ) + + +def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: + """Construct a RecordSet from a EvaluateRes object.""" + recordset = RecordSet() + + res_str = "evaluateres" + # loss + recordset.metrics_records[f"{res_str}.loss"] = MetricsRecord( + {"loss": evaluateres.loss}, + ) + + # num_examples + recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": evaluateres.num_examples}, + ) + + # metrics + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + evaluateres.metrics, # type: ignore + ) + + # status + recordset = _embed_status_into_recordset( + f"{res_str}", evaluateres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns: + """Derive GetParametersIns from a RecordSet object.""" + config_record = recordset.configs_records["getparametersins.config"] + # 
pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) + + return GetParametersIns(config=config_dict) + + +def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet: + """Construct a RecordSet from a GetParametersIns object.""" + recordset = RecordSet() + + recordset.configs_records["getparametersins.config"] = ConfigsRecord( + getparameters_ins.config, # type: ignore + ) + return recordset + + +def getparametersres_to_recordset( + getparametersres: GetParametersRes, keep_input: bool +) -> RecordSet: + """Construct a RecordSet from a GetParametersRes object.""" + recordset = RecordSet() + res_str = "getparametersres" + parameters_record = parameters_to_parametersrecord( + getparametersres.parameters, keep_input=keep_input + ) + recordset.parameters_records[f"{res_str}.parameters"] = parameters_record + + # status + recordset = _embed_status_into_recordset( + res_str, getparametersres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersres( + recordset: RecordSet, keep_input: bool +) -> GetParametersRes: + """Derive GetParametersRes from a RecordSet object.""" + res_str = "getparametersres" + parameters = parametersrecord_to_parameters( + recordset.parameters_records[f"{res_str}.parameters"], keep_input=keep_input + ) + + status = _extract_status_from_recordset(res_str, recordset) + return GetParametersRes(status=status, parameters=parameters) + + +def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: + """Derive GetPropertiesIns from a RecordSet object.""" + config_record = recordset.configs_records["getpropertiesins.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) + + return GetPropertiesIns(config=config_dict) + + +def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet: + """Construct a RecordSet from a 
GetPropertiesIns object.""" + recordset = RecordSet() + recordset.configs_records["getpropertiesins.config"] = ConfigsRecord( + getpropertiesins.config, # type: ignore + ) + return recordset + + +def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes: + """Derive GetPropertiesRes from a RecordSet object.""" + res_str = "getpropertiesres" + config_record = recordset.configs_records[f"{res_str}.properties"] + # pylint: disable-next=protected-access + properties = _check_mapping_from_recordscalartype_to_scalar(config_record._data) + + status = _extract_status_from_recordset(res_str, recordset=recordset) + + return GetPropertiesRes(status=status, properties=properties) + + +def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet: + """Construct a RecordSet from a GetPropertiesRes object.""" + recordset = RecordSet() + res_str = "getpropertiesres" + recordset.configs_records[f"{res_str}.properties"] = ConfigsRecord( + getpropertiesres.properties, # type: ignore + ) + # status + recordset = _embed_status_into_recordset( + res_str, getpropertiesres.status, recordset + ) + + return recordset diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py new file mode 100644 index 000000000000..288326dc9e83 --- /dev/null +++ b/src/py/flwr/common/recordset_compat_test.py @@ -0,0 +1,303 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""RecordSet from legacy messages tests.""" + +from copy import deepcopy +from typing import Callable, Dict + +import numpy as np +import pytest + +from .parameter import ndarrays_to_parameters +from .recordset_compat import ( + evaluateins_to_recordset, + evaluateres_to_recordset, + fitins_to_recordset, + fitres_to_recordset, + getparametersins_to_recordset, + getparametersres_to_recordset, + getpropertiesins_to_recordset, + getpropertiesres_to_recordset, + recordset_to_evaluateins, + recordset_to_evaluateres, + recordset_to_fitins, + recordset_to_fitres, + recordset_to_getparametersins, + recordset_to_getparametersres, + recordset_to_getpropertiesins, + recordset_to_getpropertiesres, +) +from .typing import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesIns, + GetPropertiesRes, + NDArrays, + Scalar, + Status, +) + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +################################################## +# Testing conversion: *Ins --> RecordSet --> *Ins +# Testing conversion: *Res <-- RecordSet <-- *Res +################################################## + + +def _get_valid_fitins() -> FitIns: + arrays = get_ndarrays() + return FitIns(parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0}) + + +def _get_valid_fitres() -> FitRes: + """Return valid parameters but potentially invalid config.""" + arrays = get_ndarrays() + metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + return FitRes( + parameters=ndarrays_to_parameters(arrays), + num_examples=1, + status=Status(code=Code(0), message=""), + metrics=metrics, + ) + + +def _get_valid_evaluateins() -> EvaluateIns: + fit_ins = _get_valid_fitins() + return EvaluateIns(parameters=fit_ins.parameters, config=fit_ins.config) 
+ + +def _get_valid_evaluateres() -> EvaluateRes: + """Return potentially invalid config.""" + metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + return EvaluateRes( + num_examples=1, + loss=0.1, + status=Status(code=Code(0), message=""), + metrics=metrics, + ) + + +def _get_valid_getparametersins() -> GetParametersIns: + config_dict: Dict[str, Scalar] = { + "a": 1.0, + "b": 3, + "c": True, + } # valid since both Ins/Res communicate over ConfigsRecord + + return GetParametersIns(config_dict) + + +def _get_valid_getparametersres() -> GetParametersRes: + arrays = get_ndarrays() + return GetParametersRes( + status=Status(code=Code(0), message=""), + parameters=ndarrays_to_parameters(arrays), + ) + + +def _get_valid_getpropertiesins() -> GetPropertiesIns: + getparamsins = _get_valid_getparametersins() + return GetPropertiesIns(config=getparamsins.config) + + +def _get_valid_getpropertiesres() -> GetPropertiesRes: + config_dict: Dict[str, Scalar] = { + "a": 1.0, + "b": 3, + "c": True, + } # valid since both Ins/Res communicate over ConfigsRecord + + return GetPropertiesRes( + status=Status(code=Code(0), message=""), properties=config_dict + ) + + +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_fitins_to_recordset_and_back( + keep_input: bool, validate_freed_fn: Callable[[FitIns, FitIns, FitIns], bool] +) -> None: + """Test conversion FitIns --> RecordSet --> FitIns.""" + fitins = _get_valid_fitins() + + fitins_copy = deepcopy(fitins) + + recordset = fitins_to_recordset(fitins, keep_input=keep_input) + + fitins_ = recordset_to_fitins(recordset, keep_input=keep_input) + + assert validate_freed_fn(fitins, fitins_copy, fitins_) + + +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, 
+ ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_fitres_to_recordset_and_back( + keep_input: bool, validate_freed_fn: Callable[[FitRes, FitRes, FitRes], bool] +) -> None: + """Test conversion FitRes --> RecordSet --> FitRes.""" + fitres = _get_valid_fitres() + + fitres_copy = deepcopy(fitres) + + recordset = fitres_to_recordset(fitres, keep_input=keep_input) + fitres_ = recordset_to_fitres(recordset, keep_input=keep_input) + + assert validate_freed_fn(fitres, fitres_copy, fitres_) + + +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_evaluateins_to_recordset_and_back( + keep_input: bool, + validate_freed_fn: Callable[[EvaluateIns, EvaluateIns, EvaluateIns], bool], +) -> None: + """Test conversion EvaluateIns --> RecordSet --> EvaluateIns.""" + evaluateins = _get_valid_evaluateins() + + evaluateins_copy = deepcopy(evaluateins) + + recordset = evaluateins_to_recordset(evaluateins, keep_input=keep_input) + + evaluateins_ = recordset_to_evaluateins(recordset, keep_input=keep_input) + + assert validate_freed_fn(evaluateins, evaluateins_copy, evaluateins_) + + +def test_evaluateres_to_recordset_and_back() -> None: + """Test conversion EvaluateRes --> RecordSet --> EvaluateRes.""" + evaluateres = _get_valid_evaluateres() + + evaluateres_copy = deepcopy(evaluateres) + + recordset = evaluateres_to_recordset(evaluateres) + evaluateres_ = recordset_to_evaluateres(recordset) + + assert evaluateres_copy == evaluateres_ + + +def test_get_properties_ins_to_recordset_and_back() -> None: + """Test conversion GetPropertiesIns --> RecordSet --> GetPropertiesIns.""" + getproperties_ins = _get_valid_getpropertiesins() + + getproperties_ins_copy = deepcopy(getproperties_ins) + + recordset = getpropertiesins_to_recordset(getproperties_ins) + 
getproperties_ins_ = recordset_to_getpropertiesins(recordset) + + assert getproperties_ins_copy == getproperties_ins_ + + +def test_get_properties_res_to_recordset_and_back() -> None: + """Test conversion GetPropertiesRes --> RecordSet --> GetPropertiesRes.""" + getproperties_res = _get_valid_getpropertiesres() + + getproperties_res_copy = deepcopy(getproperties_res) + + recordset = getpropertiesres_to_recordset(getproperties_res) + getproperties_res_ = recordset_to_getpropertiesres(recordset) + + assert getproperties_res_copy == getproperties_res_ + + +def test_get_parameters_ins_to_recordset_and_back() -> None: + """Test conversion GetParametersIns --> RecordSet --> GetParametersIns.""" + getparameters_ins = _get_valid_getparametersins() + + getparameters_ins_copy = deepcopy(getparameters_ins) + + recordset = getparametersins_to_recordset(getparameters_ins) + getparameters_ins_ = recordset_to_getparametersins(recordset) + + assert getparameters_ins_copy == getparameters_ins_ + + +@pytest.mark.parametrize( + "keep_input, validate_freed_fn", + [ + ( + False, + lambda x, x_copy, y: len(x.parameters.tensors) == 0 and x_copy == y, + ), # check tensors were freed + ( + True, + lambda x, x_copy, y: x == y, + ), + ], +) +def test_get_parameters_res_to_recordset_and_back( + keep_input: bool, + validate_freed_fn: Callable[ + [GetParametersRes, GetParametersRes, GetParametersRes], bool + ], +) -> None: + """Test conversion GetParametersRes --> RecordSet --> GetParametersRes.""" + getparameteres_res = _get_valid_getparametersres() + + getparameters_res_copy = deepcopy(getparameteres_res) + + recordset = getparametersres_to_recordset(getparameteres_res, keep_input=keep_input) + getparameteres_res_ = recordset_to_getparametersres( + recordset, keep_input=keep_input + ) + + assert validate_freed_fn( + getparameteres_res, getparameters_res_copy, getparameteres_res_ + ) diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 
a60fff57e7bf..5441e766983a 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -156,6 +156,7 @@ class RetryInvoker: >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ + # pylint: disable-next=too-many-arguments def __init__( self, wait_factory: Callable[[], Generator[float, None, None]], diff --git a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py index 57afa56b7a08..e926a9531bea 100644 --- a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py +++ b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py @@ -15,7 +15,7 @@ """Utility functions for performing operations on Numpy NDArrays.""" -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union import numpy as np from numpy.typing import DTypeLike, NDArray @@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray def parameters_multiply( - parameters: List[NDArray[Any]], multiplier: int + parameters: List[NDArray[Any]], multiplier: Union[int, float] ) -> List[NDArray[Any]]: - """Multiply parameters by an integer multiplier.""" + """Multiply parameters by an integer/float multiplier.""" return [parameters[idx] * multiplier for idx in range(len(parameters))] def parameters_divide( - parameters: List[NDArray[Any]], divisor: int + parameters: List[NDArray[Any]], divisor: Union[int, float] ) -> List[NDArray[Any]]: - """Divide weight by an integer divisor.""" + """Divide weight by an integer/float divisor.""" return [parameters[idx] / divisor for idx in range(len(parameters))] diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 8dd21a6016f1..8a15908c13c5 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -15,30 +15,55 @@ """Constants 
for the SecAgg/SecAgg+ protocol.""" -# Names of stages -STAGE_SETUP = "setup" -STAGE_SHARE_KEYS = "share_keys" -STAGE_COLLECT_MASKED_INPUT = "collect_masked_input" -STAGE_UNMASK = "unmask" -STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK) - -# All valid keys in received/replied `named_values` dictionaries -KEY_STAGE = "stage" -KEY_SAMPLE_NUMBER = "sample_num" -KEY_SECURE_ID = "secure_id" -KEY_SHARE_NUMBER = "share_num" -KEY_THRESHOLD = "threshold" -KEY_CLIPPING_RANGE = "clipping_range" -KEY_TARGET_RANGE = "target_range" -KEY_MOD_RANGE = "mod_range" -KEY_PUBLIC_KEY_1 = "pk1" -KEY_PUBLIC_KEY_2 = "pk2" -KEY_DESTINATION_LIST = "dsts" -KEY_CIPHERTEXT_LIST = "ctxts" -KEY_SOURCE_LIST = "srcs" -KEY_PARAMETERS = "params" -KEY_MASKED_PARAMETERS = "masked_params" -KEY_ACTIVE_SECURE_ID_LIST = "active_sids" -KEY_DEAD_SECURE_ID_LIST = "dead_sids" -KEY_SECURE_ID_LIST = "sids" -KEY_SHARE_LIST = "shares" +from __future__ import annotations + +RECORD_KEY_STATE = "secaggplus_state" +RECORD_KEY_CONFIGS = "secaggplus_configs" +RATIO_QUANTIZATION_RANGE = 1073741824 # 1 << 30 + + +class Stage: + """Stages for the SecAgg+ protocol.""" + + SETUP = "setup" + SHARE_KEYS = "share_keys" + COLLECT_MASKED_VECTORS = "collect_masked_vectors" + UNMASK = "unmask" + _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK) + + @classmethod + def all(cls) -> tuple[str, str, str, str]: + """Return all stages.""" + return cls._stages + + def __new__(cls) -> Stage: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Key: + """Keys for the configs in the ConfigsRecord.""" + + STAGE = "stage" + SAMPLE_NUMBER = "sample_num" + SHARE_NUMBER = "share_num" + THRESHOLD = "threshold" + CLIPPING_RANGE = "clipping_range" + TARGET_RANGE = "target_range" + MOD_RANGE = "mod_range" + MAX_WEIGHT = "max_weight" + PUBLIC_KEY_1 = "pk1" + PUBLIC_KEY_2 = "pk2" + DESTINATION_LIST = "dsts" + CIPHERTEXT_LIST = "ctxts" + SOURCE_LIST = "srcs" 
+ PARAMETERS = "params" + MASKED_PARAMETERS = "masked_params" + ACTIVE_NODE_ID_LIST = "active_nids" + DEAD_NODE_ID_LIST = "dead_nids" + NODE_ID_LIST = "nids" + SHARE_LIST = "shares" + + def __new__(cls) -> Key: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index def677e9d5d9..c373573477b9 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -23,16 +23,16 @@ def share_keys_plaintext_concat( - source: int, destination: int, b_share: bytes, sk_share: bytes + src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes ) -> bytes: """Combine arguments to bytes. Parameters ---------- - source : int - the secure ID of the source. - destination : int - the secure ID of the destination. + src_node_id : int + the node ID of the source. + dst_node_id : int + the node ID of the destination. b_share : bytes the private key share of the source sent to the destination. sk_share : bytes @@ -45,8 +45,8 @@ def share_keys_plaintext_concat( """ return b"".join( [ - int.to_bytes(source, 4, "little"), - int.to_bytes(destination, 4, "little"), + int.to_bytes(src_node_id, 8, "little", signed=True), + int.to_bytes(dst_node_id, 8, "little", signed=True), int.to_bytes(len(b_share), 4, "little"), b_share, sk_share, @@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by Returns ------- - source : int - the secure ID of the source. - destination : int - the secure ID of the destination. + src_node_id : int + the node ID of the source. + dst_node_id : int + the node ID of the destination. b_share : bytes the private key share of the source sent to the destination. sk_share : bytes the secret key share of the source sent to the destination. 
""" src, dst, mark = ( - int.from_bytes(plaintext[:4], "little"), - int.from_bytes(plaintext[4:8], "little"), - int.from_bytes(plaintext[8:12], "little"), + int.from_bytes(plaintext[:8], "little", signed=True), + int.from_bytes(plaintext[8:16], "little", signed=True), + int.from_bytes(plaintext[16:20], "little"), ) - ret = (src, dst, plaintext[12 : 12 + mark], plaintext[12 + mark :]) + ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :]) return ret diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c8c73e87e04a..6c7a077d2f9f 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,9 +15,24 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, cast - -from flwr.proto.task_pb2 import Value +from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast + +from google.protobuf.message import Message as GrpcMessage + +# pylint: disable=E0611 +from flwr.proto.error_pb2 import Error as ProtoError +from flwr.proto.node_pb2 import Node +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import BoolList, BytesList +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue +from flwr.proto.recordset_pb2 import DoubleList +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet +from flwr.proto.recordset_pb2 import Sint64List, StringList +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ( ClientMessage, Code, @@ -28,139 +43,10 @@ Status, ) -from . 
import typing - -# === ServerMessage message === - - -def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessage: - """Serialize `ServerMessage` to ProtoBuf.""" - if server_message.get_properties_ins is not None: - return ServerMessage( - get_properties_ins=get_properties_ins_to_proto( - server_message.get_properties_ins, - ) - ) - if server_message.get_parameters_ins is not None: - return ServerMessage( - get_parameters_ins=get_parameters_ins_to_proto( - server_message.get_parameters_ins, - ) - ) - if server_message.fit_ins is not None: - return ServerMessage( - fit_ins=fit_ins_to_proto( - server_message.fit_ins, - ) - ) - if server_message.evaluate_ins is not None: - return ServerMessage( - evaluate_ins=evaluate_ins_to_proto( - server_message.evaluate_ins, - ) - ) - raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf") - - -def server_message_from_proto( - server_message_proto: ServerMessage, -) -> typing.ServerMessage: - """Deserialize `ServerMessage` from ProtoBuf.""" - field = server_message_proto.WhichOneof("msg") - if field == "get_properties_ins": - return typing.ServerMessage( - get_properties_ins=get_properties_ins_from_proto( - server_message_proto.get_properties_ins, - ) - ) - if field == "get_parameters_ins": - return typing.ServerMessage( - get_parameters_ins=get_parameters_ins_from_proto( - server_message_proto.get_parameters_ins, - ) - ) - if field == "fit_ins": - return typing.ServerMessage( - fit_ins=fit_ins_from_proto( - server_message_proto.fit_ins, - ) - ) - if field == "evaluate_ins": - return typing.ServerMessage( - evaluate_ins=evaluate_ins_from_proto( - server_message_proto.evaluate_ins, - ) - ) - raise Exception( - "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" - ) - - -# === ClientMessage message === - - -def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessage: - """Serialize `ClientMessage` to ProtoBuf.""" - if 
client_message.get_properties_res is not None: - return ClientMessage( - get_properties_res=get_properties_res_to_proto( - client_message.get_properties_res, - ) - ) - if client_message.get_parameters_res is not None: - return ClientMessage( - get_parameters_res=get_parameters_res_to_proto( - client_message.get_parameters_res, - ) - ) - if client_message.fit_res is not None: - return ClientMessage( - fit_res=fit_res_to_proto( - client_message.fit_res, - ) - ) - if client_message.evaluate_res is not None: - return ClientMessage( - evaluate_res=evaluate_res_to_proto( - client_message.evaluate_res, - ) - ) - raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf") - - -def client_message_from_proto( - client_message_proto: ClientMessage, -) -> typing.ClientMessage: - """Deserialize `ClientMessage` from ProtoBuf.""" - field = client_message_proto.WhichOneof("msg") - if field == "get_properties_res": - return typing.ClientMessage( - get_properties_res=get_properties_res_from_proto( - client_message_proto.get_properties_res, - ) - ) - if field == "get_parameters_res": - return typing.ClientMessage( - get_parameters_res=get_parameters_res_from_proto( - client_message_proto.get_parameters_res, - ) - ) - if field == "fit_res": - return typing.ClientMessage( - fit_res=fit_res_from_proto( - client_message_proto.fit_res, - ) - ) - if field == "evaluate_res": - return typing.ClientMessage( - evaluate_res=evaluate_res_from_proto( - client_message_proto.evaluate_res, - ) - ) - raise Exception( - "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" - ) - +# pylint: enable=E0611 +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing +from .message import Error, Message, Metadata +from .record.typeddict import TypedDict # === Parameters message === @@ -186,26 +72,9 @@ def reconnect_ins_to_proto(ins: typing.ReconnectIns) -> ServerMessage.ReconnectI return ServerMessage.ReconnectIns() -def reconnect_ins_from_proto(msg: ServerMessage.ReconnectIns) -> typing.ReconnectIns: - """Deserialize `ReconnectIns` from ProtoBuf.""" - return typing.ReconnectIns(seconds=msg.seconds) - - # === DisconnectRes message === -def disconnect_res_to_proto(res: typing.DisconnectRes) -> ClientMessage.DisconnectRes: - """Serialize `DisconnectRes` to ProtoBuf.""" - reason_proto = Reason.UNKNOWN - if res.reason == "RECONNECT": - reason_proto = Reason.RECONNECT - elif res.reason == "POWER_DISCONNECTED": - reason_proto = Reason.POWER_DISCONNECTED - elif res.reason == "WIFI_UNAVAILABLE": - reason_proto = Reason.WIFI_UNAVAILABLE - return ClientMessage.DisconnectRes(reason=reason_proto) - - def disconnect_res_from_proto(msg: ClientMessage.DisconnectRes) -> typing.DisconnectRes: """Deserialize `DisconnectRes` from ProtoBuf.""" if msg.reason == Reason.RECONNECT: @@ -474,7 +343,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: if isinstance(scalar, str): return Scalar(string=scalar) - raise Exception( + raise ValueError( f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})" ) @@ -486,86 +355,313 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: return cast(typing.Scalar, scalar) -# === Value messages === +# === Record messages === -_python_type_to_field_name = { +_type_to_field = { float: "double", int: "sint64", bool: "bool", str: "string", bytes: "bytes", } +_list_type_to_class_and_field = { + float: (DoubleList, "double_list"), + int: (Sint64List, "sint64_list"), + bool: (BoolList, "bool_list"), + str: (StringList, "string_list"), + bytes: (BytesList, "bytes_list"), +} +T = TypeVar("T") 
-_python_list_type_to_message_and_field_name = { - float: (Value.DoubleList, "double_list"), - int: (Value.Sint64List, "sint64_list"), - bool: (Value.BoolList, "bool_list"), - str: (Value.StringList, "string_list"), - bytes: (Value.BytesList, "bytes_list"), -} +def _record_value_to_proto( + value: Any, allowed_types: List[type], proto_class: Type[T] +) -> T: + """Serialize `*RecordValue` to ProtoBuf. + + Note: `bool` MUST be put in the front of allowd_types if it exists. + """ + arg = {} + for t in allowed_types: + # Single element + # Note: `isinstance(False, int) == True`. + if isinstance(value, t): + arg[_type_to_field[t]] = value + return proto_class(**arg) + # List + if isinstance(value, list) and all(isinstance(item, t) for item in value): + list_class, field_name = _list_type_to_class_and_field[t] + arg[field_name] = list_class(vals=value) + return proto_class(**arg) + # Invalid types + raise TypeError( + f"The type of the following value is not allowed " + f"in '{proto_class.__name__}':\n{value}" + ) -def _check_value(value: typing.Value) -> None: - if isinstance(value, tuple(_python_type_to_field_name.keys())): - return - if isinstance(value, list): - if len(value) > 0 and isinstance( - value[0], tuple(_python_type_to_field_name.keys()) - ): - data_type = type(value[0]) - for element in value: - if isinstance(element, data_type): - continue - raise Exception( - f"Inconsistent type: the types of elements in the list must " - f"be the same (expected {data_type}, but got {type(element)})." - ) +def _record_value_from_proto(value_proto: GrpcMessage) -> Any: + """Deserialize `*RecordValue` from ProtoBuf.""" + value_field = cast(str, value_proto.WhichOneof("value")) + if value_field.endswith("list"): + value = list(getattr(value_proto, value_field).vals) else: - raise TypeError( - f"Accepted types: {bool, bytes, float, int, str} or " - f"list of these types." 
+ value = getattr(value_proto, value_field) + return value + + +def _record_value_dict_to_proto( + value_dict: TypedDict[str, Any], + allowed_types: List[type], + value_proto_class: Type[T], +) -> Dict[str, T]: + """Serialize the record value dict to ProtoBuf. + + Note: `bool` MUST be put in the front of allowd_types if it exists. + """ + # Move bool to the front + if bool in allowed_types and allowed_types[0] != bool: + allowed_types.remove(bool) + allowed_types.insert(0, bool) + + def proto(_v: Any) -> T: + return _record_value_to_proto(_v, allowed_types, value_proto_class) + + return {k: proto(v) for k, v in value_dict.items()} + + +def _record_value_dict_from_proto( + value_dict_proto: MutableMapping[str, Any] +) -> Dict[str, Any]: + """Deserialize the record value dict from ProtoBuf.""" + return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()} + + +def array_to_proto(array: Array) -> ProtoArray: + """Serialize Array to ProtoBuf.""" + return ProtoArray(**vars(array)) + + +def array_from_proto(array_proto: ProtoArray) -> Array: + """Deserialize Array from ProtoBuf.""" + return Array( + dtype=array_proto.dtype, + shape=list(array_proto.shape), + stype=array_proto.stype, + data=array_proto.data, + ) + + +def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: + """Serialize ParametersRecord to ProtoBuf.""" + return ProtoParametersRecord( + data_keys=record.keys(), + data_values=map(array_to_proto, record.values()), + ) + + +def parameters_record_from_proto( + record_proto: ProtoParametersRecord, +) -> ParametersRecord: + """Deserialize ParametersRecord from ProtoBuf.""" + return ParametersRecord( + array_dict=OrderedDict( + zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values)) + ), + keep_input=False, + ) + + +def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: + """Serialize MetricsRecord to ProtoBuf.""" + return ProtoMetricsRecord( + 
data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue) + ) + + +def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord: + """Deserialize MetricsRecord from ProtoBuf.""" + return MetricsRecord( + metrics_dict=cast( + Dict[str, typing.MetricsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) + + +def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: + """Serialize ConfigsRecord to ProtoBuf.""" + return ProtoConfigsRecord( + data=_record_value_dict_to_proto( + record, + [bool, int, float, str, bytes], + ProtoConfigsRecordValue, ) + ) -def value_to_proto(value: typing.Value) -> Value: - """Serialize `Value` to ProtoBuf.""" - _check_value(value) +def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord: + """Deserialize ConfigsRecord from ProtoBuf.""" + return ConfigsRecord( + configs_dict=cast( + Dict[str, typing.ConfigsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) - arg = {} - if isinstance(value, list): - msg_class, field_name = _python_list_type_to_message_and_field_name[ - type(value[0]) if len(value) > 0 else int - ] - arg[field_name] = msg_class(vals=value) - else: - arg[_python_type_to_field_name[type(value)]] = value - return Value(**arg) +# === Error message === + + +def error_to_proto(error: Error) -> ProtoError: + """Serialize Error to ProtoBuf.""" + reason = error.reason if error.reason else "" + return ProtoError(code=error.code, reason=reason) + + +def error_from_proto(error_proto: ProtoError) -> Error: + """Deserialize Error from ProtoBuf.""" + reason = error_proto.reason if len(error_proto.reason) > 0 else None + return Error(code=error_proto.code, reason=reason) + + +# === RecordSet message === + + +def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: + """Serialize RecordSet to ProtoBuf.""" + return ProtoRecordSet( + parameters={ + k: 
parameters_record_to_proto(v) + for k, v in recordset.parameters_records.items() + }, + metrics={ + k: metrics_record_to_proto(v) for k, v in recordset.metrics_records.items() + }, + configs={ + k: configs_record_to_proto(v) for k, v in recordset.configs_records.items() + }, + ) + + +def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: + """Deserialize RecordSet from ProtoBuf.""" + return RecordSet( + parameters_records={ + k: parameters_record_from_proto(v) + for k, v in recordset_proto.parameters.items() + }, + metrics_records={ + k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() + }, + configs_records={ + k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() + }, + ) -def value_from_proto(value_msg: Value) -> typing.Value: - """Deserialize `Value` from ProtoBuf.""" - value_field = cast(str, value_msg.WhichOneof("value")) - if value_field.endswith("list"): - value = list(getattr(value_msg, value_field).vals) - else: - value = getattr(value_msg, value_field) - return cast(typing.Value, value) +# === Message === + + +def message_to_taskins(message: Message) -> TaskIns: + """Create a TaskIns from the Message.""" + md = message.metadata + return TaskIns( + group_id=md.group_id, + run_id=md.run_id, + task=Task( + producer=Node(node_id=0, anonymous=True), # Assume driver node + consumer=Node(node_id=md.dst_node_id, anonymous=False), + ttl=md.ttl, + ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], + task_type=md.message_type, + recordset=( + recordset_to_proto(message.content) if message.has_content() else None + ), + error=error_to_proto(message.error) if message.has_error() else None, + ), + ) -# === Named Values === +def message_from_taskins(taskins: TaskIns) -> Message: + """Create a Message from the TaskIns.""" + # Retrieve the Metadata + metadata = Metadata( + run_id=taskins.run_id, + message_id=taskins.task_id, + src_node_id=taskins.task.producer.node_id, + 
dst_node_id=taskins.task.consumer.node_id, + reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "", + group_id=taskins.group_id, + ttl=taskins.task.ttl, + message_type=taskins.task.task_type, + ) -def named_values_to_proto( - named_values: Dict[str, typing.Value], -) -> Dict[str, Value]: - """Serialize named values to ProtoBuf.""" - return {name: value_to_proto(value) for name, value in named_values.items()} + # Construct Message + return Message( + metadata=metadata, + content=( + recordset_from_proto(taskins.task.recordset) + if taskins.task.HasField("recordset") + else None + ), + error=( + error_from_proto(taskins.task.error) + if taskins.task.HasField("error") + else None + ), + ) -def named_values_from_proto( - named_values_proto: MutableMapping[str, Value] -) -> Dict[str, typing.Value]: - """Deserialize named values from ProtoBuf.""" - return {name: value_from_proto(value) for name, value in named_values_proto.items()} +def message_to_taskres(message: Message) -> TaskRes: + """Create a TaskRes from the Message.""" + md = message.metadata + return TaskRes( + task_id="", # This will be generated by the server + group_id=md.group_id, + run_id=md.run_id, + task=Task( + producer=Node(node_id=md.src_node_id, anonymous=False), + consumer=Node(node_id=0, anonymous=True), # Assume driver node + ttl=md.ttl, + ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], + task_type=md.message_type, + recordset=( + recordset_to_proto(message.content) if message.has_content() else None + ), + error=error_to_proto(message.error) if message.has_error() else None, + ), + ) + + +def message_from_taskres(taskres: TaskRes) -> Message: + """Create a Message from the TaskIns.""" + # Retrieve the MetaData + metadata = Metadata( + run_id=taskres.run_id, + message_id=taskres.task_id, + src_node_id=taskres.task.producer.node_id, + dst_node_id=taskres.task.consumer.node_id, + reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "", + 
group_id=taskres.group_id, + ttl=taskres.task.ttl, + message_type=taskres.task.task_type, + ) + + # Construct the Message + return Message( + metadata=metadata, + content=( + recordset_from_proto(taskres.task.recordset) + if taskres.task.HasField("recordset") + else None + ), + error=( + error_from_proto(taskres.task.error) + if taskres.task.HasField("error") + else None + ), + ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index ba07890f4658..8596e5d2f330 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -14,21 +14,42 @@ # ============================================================================== """(De-)serialization tests.""" +import random +import string +from typing import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast -from typing import Dict, Union, cast +import pytest -from flwr.common import typing +# pylint: disable=E0611 from flwr.proto import transport_pb2 as pb2 +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet +# pylint: enable=E0611 +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing +from .message import Error, Message, Metadata from .serde import ( - named_values_from_proto, - named_values_to_proto, + array_from_proto, + array_to_proto, + configs_record_from_proto, + configs_record_to_proto, + message_from_taskins, + message_from_taskres, + message_to_taskins, + message_to_taskres, + metrics_record_from_proto, + metrics_record_to_proto, + parameters_record_from_proto, + parameters_record_to_proto, + recordset_from_proto, + recordset_to_proto, scalar_from_proto, scalar_to_proto, status_from_proto, status_to_proto, - value_from_proto, - value_to_proto, ) @@ -50,8 +71,8 @@ def test_serialisation_deserialisation() -> None: def test_status_to_proto() -> None: """Test status message (de-)serialization.""" # Prepare - code_msg = pb2.Code.OK - status_msg = pb2.Status(code=code_msg, message="Success") + code_msg = pb2.Code.OK # pylint: disable=E1101 + status_msg = pb2.Status(code=code_msg, message="Success") # pylint: disable=E1101 code = typing.Code.OK status = typing.Status(code=code, message="Success") @@ -66,8 +87,8 @@ def test_status_to_proto() -> None: def test_status_from_proto() -> None: """Test status message (de-)serialization.""" # Prepare - code_msg = pb2.Code.OK - status_msg = pb2.Status(code=code_msg, message="Success") + code_msg = pb2.Code.OK # pylint: disable=E1101 + status_msg = pb2.Status(code=code_msg, message="Success") # pylint: disable=E1101 code = typing.Code.OK status = typing.Status(code=code, message="Success") @@ -79,81 +100,290 @@ def test_status_from_proto() -> None: assert actual_status == status -def test_value_serialization_deserialization() -> None: - """Test if values are identical after (de-)serialization.""" - # Prepare - values = [ - # boolean scalar and list - True, - [True, False, False, True], - # bytes scalar and list - b"test \x01\x02\x03 !@#$%^&*()", - [b"\x0a\x0b", b"\x0c\x0d\x0e", b"\x0f"], - # float scalar and list - 3.14, - 
[2.714, -0.012], - # integer scalar and list - 23, - [123456], - # string scalar and list - "abcdefghijklmnopqrstuvwxy", - ["456hgdhfd", "1234567890123456789012345678901", "I'm a string."], - # empty list - [], - ] - - for value in values: - # Execute - serialized = value_to_proto(cast(typing.Value, value)) - deserialized = value_from_proto(serialized) +T = TypeVar("T") - # Assert - if isinstance(value, list): - assert isinstance(deserialized, list) - assert len(value) == len(deserialized) - for elm1, elm2 in zip(value, deserialized): - assert elm1 == elm2 + +class RecordMaker: + """A record maker based on a seeded random number generator.""" + + def __init__(self, state: int = 42) -> None: + self.rng = random.Random(state) + + def randbytes(self, n: int) -> bytes: + """Create a bytes.""" + return self.rng.getrandbits(n * 8).to_bytes(n, "little") + + def get_str(self, length: Optional[int] = None) -> str: + """Create a string.""" + char_pool = ( + string.ascii_letters + string.digits + " !@#$%^&*()_-+=[]|;':,./<>?{}" + ) + if length is None: + length = self.rng.randint(1, 10) + return "".join(self.rng.choices(char_pool, k=length)) + + def get_value(self, dtype: Type[T]) -> T: + """Create a value of a given type.""" + ret: Any = None + if dtype == bool: + ret = self.rng.random() < 0.5 + elif dtype == str: + ret = self.get_str(self.rng.randint(10, 100)) + elif dtype == int: + ret = self.rng.randint(-1 << 30, 1 << 30) + elif dtype == float: + ret = (self.rng.random() - 0.5) * (2.0 ** self.rng.randint(0, 50)) + elif dtype == bytes: + ret = self.randbytes(self.rng.randint(10, 100)) else: - assert value == deserialized + raise NotImplementedError(f"Unsupported dtype: {dtype}") + return cast(T, ret) + def array(self) -> Array: + """Create a Array.""" + dtypes = ("float", "int") + stypes = ("torch", "tf", "numpy") + max_shape_size = 100 + max_shape_dim = 10 + min_max_bytes_size = (10, 1000) -def test_named_values_serialization_deserialization() -> None: - """Test if named 
values is identical after (de-)serialization.""" + dtype = self.rng.choice(dtypes) + shape = [ + self.rng.randint(1, max_shape_size) + for _ in range(self.rng.randint(1, max_shape_dim)) + ] + stype = self.rng.choice(stypes) + data = self.randbytes(self.rng.randint(*min_max_bytes_size)) + return Array(dtype=dtype, shape=shape, stype=stype, data=data) + + def parameters_record(self) -> ParametersRecord: + """Create a ParametersRecord.""" + num_arrays = self.rng.randint(1, 5) + arrays = OrderedDict( + [(self.get_str(), self.array()) for i in range(num_arrays)] + ) + return ParametersRecord(arrays, keep_input=False) + + def metrics_record(self) -> MetricsRecord: + """Create a MetricsRecord.""" + num_entries = self.rng.randint(1, 5) + types = (float, int) + return MetricsRecord( + metrics_dict={ + self.get_str(): self.get_value(self.rng.choice(types)) + for _ in range(num_entries) + }, + keep_input=False, + ) + + def configs_record(self) -> ConfigsRecord: + """Create a ConfigsRecord.""" + num_entries = self.rng.randint(1, 5) + types = (str, int, float, bytes, bool) + return ConfigsRecord( + configs_dict={ + self.get_str(): self.get_value(self.rng.choice(types)) + for _ in range(num_entries) + }, + keep_input=False, + ) + + def recordset( + self, + num_params_records: int, + num_metrics_records: int, + num_configs_records: int, + ) -> RecordSet: + """Create a RecordSet.""" + return RecordSet( + parameters_records={ + self.get_str(): self.parameters_record() + for _ in range(num_params_records) + }, + metrics_records={ + self.get_str(): self.metrics_record() + for _ in range(num_metrics_records) + }, + configs_records={ + self.get_str(): self.configs_record() + for _ in range(num_configs_records) + }, + ) + + def metadata(self) -> Metadata: + """Create a Metadata.""" + return Metadata( + run_id=self.rng.randint(0, 1 << 30), + message_id=self.get_str(64), + group_id=self.get_str(30), + src_node_id=self.rng.randint(0, 1 << 63), + dst_node_id=self.rng.randint(0, 1 << 63), + 
reply_to_message=self.get_str(64), + ttl=self.get_str(10), + message_type=self.get_str(10), + ) + + +def test_array_serialization_deserialization() -> None: + """Test serialization and deserialization of Array.""" # Prepare - values = [ - # boolean scalar and list - True, - [True, False, False, True], - # bytes scalar and list - b"test \x01\x02\x03 !@#$%^&*()", - [b"\x0a\x0b", b"\x0c\x0d\x0e", b"\x0f"], - # float scalar and list - 3.14, - [2.714, -0.012], - # integer scalar and list - 23, - [123456], - # string scalar and list - "abcdefghijklmnopqrstuvwxy", - ["456hgdhfd", "1234567890123456789012345678901", "I'm a string."], - # empty list - [], - ] - named_values = {f"value {i}": value for i, value in enumerate(values)} + maker = RecordMaker() + original = maker.array() # Execute - serialized = named_values_to_proto(cast(Dict[str, typing.Value], named_values)) - deserialized = named_values_from_proto(serialized) + proto = array_to_proto(original) + deserialized = array_from_proto(proto) # Assert - assert len(named_values) == len(deserialized) - for name in named_values: - expected = named_values[name] - actual = deserialized[name] - if isinstance(expected, list): - assert isinstance(actual, list) - assert len(expected) == len(actual) - for elm1, elm2 in zip(expected, actual): - assert elm1 == elm2 - else: - assert expected == actual + assert isinstance(proto, ProtoArray) + assert original == deserialized + + +def test_parameters_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ParametersRecord.""" + # Prepare + maker = RecordMaker() + original = maker.parameters_record() + + # Execute + proto = parameters_record_to_proto(original) + deserialized = parameters_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoParametersRecord) + assert original == deserialized + + +def test_metrics_record_serialization_deserialization() -> None: + """Test serialization and deserialization of MetricsRecord.""" + # 
Prepare + maker = RecordMaker() + original = maker.metrics_record() + + # Execute + proto = metrics_record_to_proto(original) + deserialized = metrics_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoMetricsRecord) + assert original == deserialized + + +def test_configs_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ConfigsRecord.""" + # Prepare + maker = RecordMaker() + original = maker.configs_record() + + # Execute + proto = configs_record_to_proto(original) + deserialized = configs_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoConfigsRecord) + assert original == deserialized + + +def test_recordset_serialization_deserialization() -> None: + """Test serialization and deserialization of RecordSet.""" + # Prepare + maker = RecordMaker(state=0) + original = maker.recordset(2, 2, 1) + + # Execute + proto = recordset_to_proto(original) + deserialized = recordset_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoRecordSet) + assert original == deserialized + + +@pytest.mark.parametrize( + "content_fn, error_fn", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + ), # check when only content is set + (None, lambda code: Error(code=code)), # check when only error is set + ], +) +def test_message_to_and_from_taskins( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], +) -> None: + """Test Message to and from TaskIns.""" + # Prepare + + maker = RecordMaker(state=1) + metadata = maker.metadata() + # pylint: disable-next=protected-access + metadata._src_node_id = 0 # Assume driver node + + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + # Execute + taskins = message_to_taskins(original) + taskins.task_id = metadata.message_id + deserialized = message_from_taskins(taskins) + + # Assert + if 
original.has_content(): + assert original.content == deserialized.content + if original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata + + +@pytest.mark.parametrize( + "content_fn, error_fn", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + ), # check when only content is set + (None, lambda code: Error(code=code)), # check when only error is set + ], +) +def test_message_to_and_from_taskres( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], +) -> None: + """Test Message to and from TaskRes.""" + # Prepare + maker = RecordMaker(state=2) + metadata = maker.metadata() + metadata.dst_node_id = 0 # Assume driver node + + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + # Execute + taskres = message_to_taskres(original) + taskres.task_id = metadata.message_id + deserialized = message_from_taskres(taskres) + + # Assert + if original.has_content(): + assert original.content == deserialized.content + if original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index fed8b5a978bc..8eb594085d31 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -32,7 +32,7 @@ FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1") FLWR_TELEMETRY_LOGGING = os.getenv("FLWR_TELEMETRY_LOGGING", "0") -TELEMETRY_EVENTS_URL = "https://telemetry.flower.dev/api/v1/event" +TELEMETRY_EVENTS_URL = "https://telemetry.flower.ai/api/v1/event" LOGGER_NAME = "flwr-telemetry" LOGGER_LEVEL = logging.DEBUG @@ -137,8 +137,8 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A RUN_FLEET_API_LEAVE = auto() # Driver API and Fleet API - RUN_SERVER_ENTER = auto() - RUN_SERVER_LEAVE = 
auto() + RUN_SUPERLINK_ENTER = auto() + RUN_SUPERLINK_LEAVE = auto() # Simulation START_SIMULATION_ENTER = auto() @@ -152,9 +152,13 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A START_DRIVER_ENTER = auto() START_DRIVER_LEAVE = auto() - # SuperNode: flower-client - RUN_CLIENT_ENTER = auto() - RUN_CLIENT_LEAVE = auto() + # flower-client-app + RUN_CLIENT_APP_ENTER = auto() + RUN_CLIENT_APP_LEAVE = auto() + + # flower-server-app + RUN_SERVER_APP_ENTER = auto() + RUN_SERVER_APP_LEAVE = auto() # Use the ThreadPoolExecutor with max_workers=1 to have a queue diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 6c0266f5eec8..d6b2ec9b158c 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -45,6 +45,15 @@ List[str], ] +# Value types for common.MetricsRecord +MetricsScalar = Union[int, float] +MetricsScalarList = Union[List[int], List[float]] +MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] +# Value types for common.ConfigsRecord +ConfigsScalar = Union[MetricsScalar, str, bytes, bool] +ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]] +ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList] + Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index 9545b4b68073..6808c66606b1 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -1,6 +1,5 @@ """Flower package version helper.""" - import importlib.metadata as importlib_metadata from typing import Tuple diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py deleted file mode 100644 index 91b4fd30bc4b..000000000000 --- a/src/py/flwr/driver/app_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower Driver app tests.""" -# pylint: disable=no-self-use - - -import threading -import time -import unittest -from unittest.mock import MagicMock - -from flwr.driver.app import update_client_manager -from flwr.proto.driver_pb2 import CreateWorkloadResponse, GetNodesResponse -from flwr.proto.node_pb2 import Node -from flwr.server.client_manager import SimpleClientManager - - -class TestClientManagerWithDriver(unittest.TestCase): - """Tests for ClientManager. - - Considering multi-threading, all tests assume that the `update_client_manager()` - updates the ClientManager every 3 seconds. 
- """ - - def test_simple_client_manager_update(self) -> None: - """Tests if the node update works correctly.""" - # Prepare - expected_nodes = [Node(node_id=i, anonymous=False) for i in range(100)] - expected_updated_nodes = [ - Node(node_id=i, anonymous=False) for i in range(80, 120) - ] - driver = MagicMock() - driver.stub = "driver stub" - driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1) - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) - client_manager = SimpleClientManager() - lock = threading.Lock() - - # Execute - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - client_manager, - lock, - ), - daemon=True, - ) - thread.start() - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_nodes)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_updated_nodes) - while True: - with lock: - if len(client_manager.all()) == len(expected_updated_nodes): - break - time.sleep(1.3) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Simulate `driver.disconnect()` - driver.stub = None - - # Assert - driver.create_workload.assert_called_once() - assert node_ids == {node.node_id for node in expected_nodes} - assert updated_node_ids == {node.node_id for node in expected_updated_nodes} - - # Exit - thread.join() diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py deleted file mode 100644 index f1a7c6663c11..000000000000 --- a/src/py/flwr/driver/driver.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower driver service client.""" - - -from typing import Iterable, List, Optional, Tuple - -from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - GetNodesRequest, - PullTaskResRequest, - PushTaskInsRequest, -) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes - - -class Driver: - """`Driver` class provides an interface to the Driver API. - - Parameters - ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. 
- """ - - def __init__( - self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, - ) -> None: - self.addr = driver_service_address - self.certificates = certificates - self.grpc_driver: Optional[GrpcDriver] = None - self.workload_id: Optional[int] = None - self.node = Node(node_id=0, anonymous=True) - - def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: - # Check if the GrpcDriver is initialized - if self.grpc_driver is None or self.workload_id is None: - # Connect and create workload - self.grpc_driver = GrpcDriver( - driver_service_address=self.addr, certificates=self.certificates - ) - self.grpc_driver.connect() - res = self.grpc_driver.create_workload(CreateWorkloadRequest()) - self.workload_id = res.workload_id - - return self.grpc_driver, self.workload_id - - def get_nodes(self) -> List[Node]: - """Get node IDs.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() - - # Call GrpcDriver method - res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id)) - return list(res.nodes) - - def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: - """Schedule tasks.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() - - # Set workload_id - for task_ins in task_ins_list: - task_ins.workload_id = workload_id - - # Call GrpcDriver method - res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) - return list(res.task_ids) - - def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: - """Get task results.""" - grpc_driver, _ = self._get_grpc_driver_and_workload_id() - - # Call GrpcDriver method - res = grpc_driver.pull_task_res( - PullTaskResRequest(node=self.node, task_ids=task_ids) - ) - return list(res.task_res_list) - - def __del__(self) -> None: - """Disconnect GrpcDriver if connected.""" - # Check if GrpcDriver is initialized - if self.grpc_driver is None: - return - - # Disconnect - 
self.grpc_driver.disconnect() diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py deleted file mode 100644 index 6d60fc49159b..000000000000 --- a/src/py/flwr/driver/driver_client_proxy.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower ClientProxy implementation for Driver API.""" - - -import time -from typing import List, Optional, cast - -from flwr import common -from flwr.common import serde -from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2 -from flwr.server.client_proxy import ClientProxy - -from .grpc_driver import GrpcDriver - -SLEEP_TIME = 1 - - -class DriverClientProxy(ClientProxy): - """Flower client proxy which delegates work using the Driver API.""" - - def __init__( - self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int - ): - super().__init__(str(node_id)) - self.node_id = node_id - self.driver = driver - self.workload_id = workload_id - self.anonymous = anonymous - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( - serde.server_message_to_proto( - server_message=common.ServerMessage(get_properties_ins=ins) - ) - ) - return cast( - 
common.GetPropertiesRes, - self._send_receive_msg(server_message_proto, timeout).get_properties_res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> common.GetParametersRes: - """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( - serde.server_message_to_proto( - server_message=common.ServerMessage(get_parameters_ins=ins) - ) - ) - return cast( - common.GetParametersRes, - self._send_receive_msg(server_message_proto, timeout).get_parameters_res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: - """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( - serde.server_message_to_proto( - server_message=common.ServerMessage(fit_ins=ins) - ) - ) - return cast( - common.FitRes, - self._send_receive_msg(server_message_proto, timeout).fit_res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( - serde.server_message_to_proto( - server_message=common.ServerMessage(evaluate_ins=ins) - ) - ) - return cast( - common.EvaluateRes, - self._send_receive_msg(server_message_proto, timeout).evaluate_res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) - - def _send_receive_msg( - self, server_message: transport_pb2.ServerMessage, timeout: Optional[float] - ) -> transport_pb2.ClientMessage: - task_ins = task_pb2.TaskIns( - task_id="", - group_id="", - workload_id=self.workload_id, - task=task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=node_pb2.Node( - node_id=self.node_id, - anonymous=self.anonymous, 
- ), - legacy_server_message=server_message, - ), - ) - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=[task_ins]) - - # Send TaskIns to Driver API - push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req) - - if len(push_task_ins_res.task_ids) != 1: - raise ValueError("Unexpected number of task_ids") - - task_id = push_task_ins_res.task_ids[0] - if task_id == "": - raise ValueError(f"Failed to schedule task for node {self.node_id}") - - if timeout: - start_time = time.time() - - while True: - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=[task_id], - ) - - # Ask Driver API for TaskRes - pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req) - - task_res_list: List[task_pb2.TaskRes] = list( - pull_task_res_res.task_res_list - ) - if len(task_res_list) == 1: - task_res = task_res_list[0] - return serde.client_message_from_proto( # type: ignore - task_res.task.legacy_client_message - ) - - if timeout is not None and time.time() > start_time + timeout: - raise RuntimeError("Timeout reached") - time.sleep(SLEEP_TIME) diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py deleted file mode 100644 index 82b5b46d7810..000000000000 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""DriverClientProxy tests.""" - - -import unittest -from unittest.mock import MagicMock - -import numpy as np - -import flwr -from flwr.common.typing import Config, GetParametersIns -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 -from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar - -MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") - -CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} - - -class DriverClientProxyTestCase(unittest.TestCase): - """Tests for DriverClientProxy.""" - - def setUp(self) -> None: - """Set up mocks for tests.""" - self.driver = MagicMock() - self.driver.get_nodes.return_value = driver_pb2.GetNodesResponse( - nodes=[node_pb2.Node(node_id=1, anonymous=False)] - ) - - def test_get_properties(self) -> None: - """Test positive case.""" - # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] - ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES - ) - ) - ), - ) - ] - ) - client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 - ) - request_properties: Config = {"tensor_type": "str"} - ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( - config=request_properties - ) - - # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) - - # Assert - assert value.properties["tensor_type"] == 
"numpy.ndarray" - - def test_get_parameters(self) -> None: - """Test positive case.""" - # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] - ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( - parameters=MESSAGE_PARAMETERS, - ) - ) - ), - ) - ] - ) - client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 - ) - get_parameters_ins = GetParametersIns(config={}) - - # Execute - value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None - ) - - # Assert - assert value.parameters.tensors[0] == b"abc" - - def test_fit(self) -> None: - """Test positive case.""" - # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] - ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( - parameters=MESSAGE_PARAMETERS, - num_examples=10, - ) - ) - ), - ) - ] - ) - client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 - ) - parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) - ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) - - # Execute - fit_res = client.fit(ins=ins, timeout=None) - - # Assert - assert fit_res.parameters.tensor_type == "np" - assert fit_res.parameters.tensors[0] == b"abc" - assert fit_res.num_examples == 10 - - def test_evaluate(self) -> None: - """Test positive case.""" - # 
Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] - ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 - ) - ) - ), - ) - ] - ) - client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 - ) - parameters = flwr.common.Parameters(tensors=[], tensor_type="np") - evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) - - # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) - - # Assert - assert 0.0 == evaluate_res.loss - assert 0 == evaluate_res.num_examples diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py deleted file mode 100644 index 820018788a8f..000000000000 --- a/src/py/flwr/driver/driver_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for driver SDK.""" - - -import unittest -from unittest.mock import Mock, patch - -from flwr.driver.driver import Driver -from flwr.proto.driver_pb2 import ( - GetNodesRequest, - PullTaskResRequest, - PushTaskInsRequest, -) -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes - - -class TestDriver(unittest.TestCase): - """Tests for `Driver` class.""" - - def setUp(self) -> None: - """Initialize mock GrpcDriver and Driver instance before each test.""" - mock_response = Mock() - mock_response.workload_id = 61016 - self.mock_grpc_driver = Mock() - self.mock_grpc_driver.create_workload.return_value = mock_response - self.patcher = patch( - "flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver - ) - self.patcher.start() - self.driver = Driver() - - def tearDown(self) -> None: - """Cleanup after each test.""" - self.patcher.stop() - - def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriver doesn't initialize if workload is created.""" - # Prepare - self.driver.grpc_driver = self.mock_grpc_driver - self.driver.workload_id = 61016 - - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() - - # Assert - self.mock_grpc_driver.connect.assert_not_called() - - def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriver initialization when workload is not created.""" - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() - - # Assert - self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(self.driver.workload_id, 61016) - - def test_get_nodes(self) -> None: - """Test retrieval of nodes.""" - # Prepare - mock_response = Mock() - mock_response.nodes = [Mock(), Mock()] - self.mock_grpc_driver.get_nodes.return_value = mock_response - - # Execute - nodes = self.driver.get_nodes() - args, kwargs = 
self.mock_grpc_driver.get_nodes.call_args - - # Assert - self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertIsInstance(args[0], GetNodesRequest) - self.assertEqual(args[0].workload_id, 61016) - self.assertEqual(nodes, mock_response.nodes) - - def test_push_task_ins(self) -> None: - """Test pushing task instructions.""" - # Prepare - mock_response = Mock() - mock_response.task_ids = ["id1", "id2"] - self.mock_grpc_driver.push_task_ins.return_value = mock_response - task_ins_list = [TaskIns(), TaskIns()] - - # Execute - task_ids = self.driver.push_task_ins(task_ins_list) - args, kwargs = self.mock_grpc_driver.push_task_ins.call_args - - # Assert - self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertIsInstance(args[0], PushTaskInsRequest) - self.assertEqual(task_ids, mock_response.task_ids) - for task_ins in args[0].task_ins_list: - self.assertEqual(task_ins.workload_id, 61016) - - def test_pull_task_res_with_given_task_ids(self) -> None: - """Test pulling task results with specific task IDs.""" - # Prepare - mock_response = Mock() - mock_response.task_res_list = [ - TaskRes(task=Task(ancestry=["id2"])), - TaskRes(task=Task(ancestry=["id3"])), - ] - self.mock_grpc_driver.pull_task_res.return_value = mock_response - task_ids = ["id1", "id2", "id3"] - - # Execute - task_res_list = self.driver.pull_task_res(task_ids) - args, kwargs = self.mock_grpc_driver.pull_task_res.call_args - - # Assert - self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertIsInstance(args[0], PullTaskResRequest) - self.assertEqual(args[0].task_ids, task_ids) - self.assertEqual(task_res_list, mock_response.task_res_list) - - def test_del_with_initialized_driver(self) -> None: - """Test cleanup behavior when Driver is initialized.""" - # Prepare - # pylint: 
disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() - - # Execute - self.driver.__del__() - - # Assert - self.mock_grpc_driver.disconnect.assert_called_once() - - def test_del_with_uninitialized_driver(self) -> None: - """Test cleanup behavior when Driver is not initialized.""" - # Execute - self.driver.__del__() - - # Assert - self.mock_grpc_driver.disconnect.assert_not_called() diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index c138507e03e9..fe9c33da0fa9 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/driver.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,94 +16,29 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 
\x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x12\n\x10\x43reateRunRequest\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc1\x02\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') - - -_CREATEWORKLOADREQUEST = DESCRIPTOR.message_types_by_name['CreateWorkloadRequest'] -_CREATEWORKLOADRESPONSE = 
DESCRIPTOR.message_types_by_name['CreateWorkloadResponse'] -_GETNODESREQUEST = DESCRIPTOR.message_types_by_name['GetNodesRequest'] -_GETNODESRESPONSE = DESCRIPTOR.message_types_by_name['GetNodesResponse'] -_PUSHTASKINSREQUEST = DESCRIPTOR.message_types_by_name['PushTaskInsRequest'] -_PUSHTASKINSRESPONSE = DESCRIPTOR.message_types_by_name['PushTaskInsResponse'] -_PULLTASKRESREQUEST = DESCRIPTOR.message_types_by_name['PullTaskResRequest'] -_PULLTASKRESRESPONSE = DESCRIPTOR.message_types_by_name['PullTaskResResponse'] -CreateWorkloadRequest = _reflection.GeneratedProtocolMessageType('CreateWorkloadRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadRequest) - }) -_sym_db.RegisterMessage(CreateWorkloadRequest) - -CreateWorkloadResponse = _reflection.GeneratedProtocolMessageType('CreateWorkloadResponse', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadResponse) - }) -_sym_db.RegisterMessage(CreateWorkloadResponse) - -GetNodesRequest = _reflection.GeneratedProtocolMessageType('GetNodesRequest', (_message.Message,), { - 'DESCRIPTOR' : _GETNODESREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.GetNodesRequest) - }) -_sym_db.RegisterMessage(GetNodesRequest) - -GetNodesResponse = _reflection.GeneratedProtocolMessageType('GetNodesResponse', (_message.Message,), { - 'DESCRIPTOR' : _GETNODESRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.GetNodesResponse) - }) -_sym_db.RegisterMessage(GetNodesResponse) - -PushTaskInsRequest = _reflection.GeneratedProtocolMessageType('PushTaskInsRequest', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKINSREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # 
@@protoc_insertion_point(class_scope:flwr.proto.PushTaskInsRequest) - }) -_sym_db.RegisterMessage(PushTaskInsRequest) - -PushTaskInsResponse = _reflection.GeneratedProtocolMessageType('PushTaskInsResponse', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKINSRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskInsResponse) - }) -_sym_db.RegisterMessage(PushTaskInsResponse) - -PullTaskResRequest = _reflection.GeneratedProtocolMessageType('PullTaskResRequest', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKRESREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskResRequest) - }) -_sym_db.RegisterMessage(PullTaskResRequest) - -PullTaskResResponse = _reflection.GeneratedProtocolMessageType('PullTaskResResponse', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKRESRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskResResponse) - }) -_sym_db.RegisterMessage(PullTaskResResponse) - -_DRIVER = DESCRIPTOR.services_by_name['Driver'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _CREATEWORKLOADREQUEST._serialized_start=85 - _CREATEWORKLOADREQUEST._serialized_end=108 - _CREATEWORKLOADRESPONSE._serialized_start=110 - _CREATEWORKLOADRESPONSE._serialized_end=155 - _GETNODESREQUEST._serialized_start=157 - _GETNODESREQUEST._serialized_end=195 - _GETNODESRESPONSE._serialized_start=197 - _GETNODESRESPONSE._serialized_end=248 - _PUSHTASKINSREQUEST._serialized_start=250 - _PUSHTASKINSREQUEST._serialized_end=314 - _PUSHTASKINSRESPONSE._serialized_start=316 - _PUSHTASKINSRESPONSE._serialized_end=355 - _PULLTASKRESREQUEST._serialized_start=357 - _PULLTASKRESREQUEST._serialized_end=427 - 
_PULLTASKRESRESPONSE._serialized_start=429 - _PULLTASKRESRESPONSE._serialized_end=494 - _DRIVER._serialized_start=497 - _DRIVER._serialized_end=833 + _globals['_CREATERUNREQUEST']._serialized_start=85 + _globals['_CREATERUNREQUEST']._serialized_end=103 + _globals['_CREATERUNRESPONSE']._serialized_start=105 + _globals['_CREATERUNRESPONSE']._serialized_end=140 + _globals['_GETNODESREQUEST']._serialized_start=142 + _globals['_GETNODESREQUEST']._serialized_end=175 + _globals['_GETNODESRESPONSE']._serialized_start=177 + _globals['_GETNODESRESPONSE']._serialized_end=228 + _globals['_PUSHTASKINSREQUEST']._serialized_start=230 + _globals['_PUSHTASKINSREQUEST']._serialized_end=294 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=296 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=335 + _globals['_PULLTASKRESREQUEST']._serialized_start=337 + _globals['_PULLTASKRESREQUEST']._serialized_end=407 + _globals['_PULLTASKRESRESPONSE']._serialized_start=409 + _globals['_PULLTASKRESRESPONSE']._serialized_end=474 + _globals['_DRIVER']._serialized_start=477 + _globals['_DRIVER']._serialized_end=798 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 8b940972cb6d..8dc254a55e8c 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -13,34 +13,34 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class CreateWorkloadRequest(google.protobuf.message.Message): - """CreateWorkload""" +class CreateRunRequest(google.protobuf.message.Message): + """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor def __init__(self, ) -> None: ... 
-global___CreateWorkloadRequest = CreateWorkloadRequest +global___CreateRunRequest = CreateRunRequest -class CreateWorkloadResponse(google.protobuf.message.Message): +class CreateRunResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... -global___CreateWorkloadResponse = CreateWorkloadResponse + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___CreateRunResponse = CreateRunResponse class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... global___GetNodesRequest = GetNodesRequest class GetNodesResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index ea33b843d945..ac6815023ebd 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -14,10 +14,10 @@ def __init__(self, channel): Args: channel: A grpc.Channel. 
""" - self.CreateWorkload = channel.unary_unary( - '/flwr.proto.Driver/CreateWorkload', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + self.CreateRun = channel.unary_unary( + '/flwr.proto.Driver/CreateRun', + request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, ) self.GetNodes = channel.unary_unary( '/flwr.proto.Driver/GetNodes', @@ -39,8 +39,8 @@ def __init__(self, channel): class DriverServicer(object): """Missing associated documentation comment in .proto file.""" - def CreateWorkload(self, request, context): - """Request workload_id + def CreateRun(self, request, context): + """Request run_id """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -70,10 +70,10 @@ def PullTaskRes(self, request, context): def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { - 'CreateWorkload': grpc.unary_unary_rpc_method_handler( - servicer.CreateWorkload, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.SerializeToString, + 'CreateRun': grpc.unary_unary_rpc_method_handler( + servicer.CreateRun, + request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, ), 'GetNodes': grpc.unary_unary_rpc_method_handler( servicer.GetNodes, @@ -101,7 +101,7 @@ class Driver(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def CreateWorkload(request, + def CreateRun(request, target, options=(), channel_credentials=None, @@ -111,9 +111,9 @@ def CreateWorkload(request, wait_for_ready=None, timeout=None, metadata=None): 
- return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateWorkload', - flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateRun', + flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 1b10d71e943d..43cf45f39b25 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -8,10 +8,10 @@ import grpc class DriverStub: def __init__(self, channel: grpc.Channel) -> None: ... - CreateWorkload: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateWorkloadRequest, - flwr.proto.driver_pb2.CreateWorkloadResponse] - """Request workload_id""" + CreateRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.CreateRunRequest, + flwr.proto.driver_pb2.CreateRunResponse] + """Request run_id""" GetNodes: grpc.UnaryUnaryMultiCallable[ flwr.proto.driver_pb2.GetNodesRequest, @@ -31,11 +31,11 @@ class DriverStub: class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod - def CreateWorkload(self, - request: flwr.proto.driver_pb2.CreateWorkloadRequest, + def CreateRun(self, + request: flwr.proto.driver_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateWorkloadResponse: - """Request workload_id""" + ) -> flwr.proto.driver_pb2.CreateRunResponse: + """Request run_id""" pass @abc.abstractmethod diff --git a/src/py/flwr/proto/error_pb2.py b/src/py/flwr/proto/error_pb2.py new file mode 100644 index 000000000000..41721ae08804 --- /dev/null +++ b/src/py/flwr/proto/error_pb2.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by the 
protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/error.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_ERROR']._serialized_start=38 + _globals['_ERROR']._serialized_end=75 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/error_pb2.pyi b/src/py/flwr/proto/error_pb2.pyi new file mode 100644 index 000000000000..1811e5aa0ca8 --- /dev/null +++ b/src/py/flwr/proto/error_pb2.pyi @@ -0,0 +1,25 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class Error(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + CODE_FIELD_NUMBER: builtins.int + REASON_FIELD_NUMBER: builtins.int + code: builtins.int + reason: typing.Text + def __init__(self, + *, + code: builtins.int = ..., + reason: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ... 
+global___Error = Error diff --git a/src/py/flwr/proto/error_pb2_grpc.py b/src/py/flwr/proto/error_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/error_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/error_pb2_grpc.pyi b/src/py/flwr/proto/error_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/error_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index e86a53e2139e..e8443c296f0c 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/fleet.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -18,115 +18,33 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 
\x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x02\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') - - -_CREATENODEREQUEST = DESCRIPTOR.message_types_by_name['CreateNodeRequest'] -_CREATENODERESPONSE = DESCRIPTOR.message_types_by_name['CreateNodeResponse'] -_DELETENODEREQUEST = DESCRIPTOR.message_types_by_name['DeleteNodeRequest'] -_DELETENODERESPONSE = DESCRIPTOR.message_types_by_name['DeleteNodeResponse'] -_PULLTASKINSREQUEST = DESCRIPTOR.message_types_by_name['PullTaskInsRequest'] -_PULLTASKINSRESPONSE = DESCRIPTOR.message_types_by_name['PullTaskInsResponse'] -_PUSHTASKRESREQUEST = DESCRIPTOR.message_types_by_name['PushTaskResRequest'] -_PUSHTASKRESRESPONSE = DESCRIPTOR.message_types_by_name['PushTaskResResponse'] -_PUSHTASKRESRESPONSE_RESULTSENTRY = _PUSHTASKRESRESPONSE.nested_types_by_name['ResultsEntry'] -_RECONNECT = 
DESCRIPTOR.message_types_by_name['Reconnect'] -CreateNodeRequest = _reflection.GeneratedProtocolMessageType('CreateNodeRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATENODEREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateNodeRequest) - }) -_sym_db.RegisterMessage(CreateNodeRequest) - -CreateNodeResponse = _reflection.GeneratedProtocolMessageType('CreateNodeResponse', (_message.Message,), { - 'DESCRIPTOR' : _CREATENODERESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateNodeResponse) - }) -_sym_db.RegisterMessage(CreateNodeResponse) - -DeleteNodeRequest = _reflection.GeneratedProtocolMessageType('DeleteNodeRequest', (_message.Message,), { - 'DESCRIPTOR' : _DELETENODEREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.DeleteNodeRequest) - }) -_sym_db.RegisterMessage(DeleteNodeRequest) - -DeleteNodeResponse = _reflection.GeneratedProtocolMessageType('DeleteNodeResponse', (_message.Message,), { - 'DESCRIPTOR' : _DELETENODERESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.DeleteNodeResponse) - }) -_sym_db.RegisterMessage(DeleteNodeResponse) - -PullTaskInsRequest = _reflection.GeneratedProtocolMessageType('PullTaskInsRequest', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKINSREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskInsRequest) - }) -_sym_db.RegisterMessage(PullTaskInsRequest) - -PullTaskInsResponse = _reflection.GeneratedProtocolMessageType('PullTaskInsResponse', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKINSRESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskInsResponse) - }) -_sym_db.RegisterMessage(PullTaskInsResponse) - -PushTaskResRequest = _reflection.GeneratedProtocolMessageType('PushTaskResRequest', 
(_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKRESREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResRequest) - }) -_sym_db.RegisterMessage(PushTaskResRequest) - -PushTaskResResponse = _reflection.GeneratedProtocolMessageType('PushTaskResResponse', (_message.Message,), { - - 'ResultsEntry' : _reflection.GeneratedProtocolMessageType('ResultsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKRESRESPONSE_RESULTSENTRY, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResResponse.ResultsEntry) - }) - , - 'DESCRIPTOR' : _PUSHTASKRESRESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResResponse) - }) -_sym_db.RegisterMessage(PushTaskResResponse) -_sym_db.RegisterMessage(PushTaskResResponse.ResultsEntry) - -Reconnect = _reflection.GeneratedProtocolMessageType('Reconnect', (_message.Message,), { - 'DESCRIPTOR' : _RECONNECT, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Reconnect) - }) -_sym_db.RegisterMessage(Reconnect) - -_FLEET = DESCRIPTOR.services_by_name['Fleet'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.fleet_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _PUSHTASKRESRESPONSE_RESULTSENTRY._options = None - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_options = b'8\001' - _CREATENODEREQUEST._serialized_start=84 - _CREATENODEREQUEST._serialized_end=103 - _CREATENODERESPONSE._serialized_start=105 - _CREATENODERESPONSE._serialized_end=157 - _DELETENODEREQUEST._serialized_start=159 - _DELETENODEREQUEST._serialized_end=210 - _DELETENODERESPONSE._serialized_start=212 - _DELETENODERESPONSE._serialized_end=232 - _PULLTASKINSREQUEST._serialized_start=234 - _PULLTASKINSREQUEST._serialized_end=304 - 
_PULLTASKINSRESPONSE._serialized_start=306 - _PULLTASKINSRESPONSE._serialized_end=413 - _PUSHTASKRESREQUEST._serialized_start=415 - _PUSHTASKRESREQUEST._serialized_end=479 - _PUSHTASKRESRESPONSE._serialized_start=482 - _PUSHTASKRESRESPONSE._serialized_end=656 - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_start=610 - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_end=656 - _RECONNECT._serialized_start=658 - _RECONNECT._serialized_end=688 - _FLEET._serialized_start=691 - _FLEET._serialized_end=1020 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' + _globals['_CREATENODEREQUEST']._serialized_start=84 + _globals['_CREATENODEREQUEST']._serialized_end=103 + _globals['_CREATENODERESPONSE']._serialized_start=105 + _globals['_CREATENODERESPONSE']._serialized_end=157 + _globals['_DELETENODEREQUEST']._serialized_start=159 + _globals['_DELETENODEREQUEST']._serialized_end=210 + _globals['_DELETENODERESPONSE']._serialized_start=212 + _globals['_DELETENODERESPONSE']._serialized_end=232 + _globals['_PULLTASKINSREQUEST']._serialized_start=234 + _globals['_PULLTASKINSREQUEST']._serialized_end=304 + _globals['_PULLTASKINSRESPONSE']._serialized_start=306 + _globals['_PULLTASKINSRESPONSE']._serialized_end=413 + _globals['_PUSHTASKRESREQUEST']._serialized_start=415 + _globals['_PUSHTASKRESREQUEST']._serialized_end=479 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=482 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=656 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=610 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=656 + _globals['_RECONNECT']._serialized_start=658 + _globals['_RECONNECT']._serialized_end=688 + _globals['_FLEET']._serialized_start=691 + _globals['_FLEET']._serialized_end=1020 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/node_pb2.py b/src/py/flwr/proto/node_pb2.py index 9d91900d8f53..b300f2c562c2 100644 --- 
a/src/py/flwr/proto/node_pb2.py +++ b/src/py/flwr/proto/node_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/node.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,19 +16,11 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"*\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x11\n\tanonymous\x18\x02 \x01(\x08\x62\x06proto3') - - -_NODE = DESCRIPTOR.message_types_by_name['Node'] -Node = _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), { - 'DESCRIPTOR' : _NODE, - '__module__' : 'flwr.proto.node_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Node) - }) -_sym_db.RegisterMessage(Node) - +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.node_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _NODE._serialized_start=37 - _NODE._serialized_end=79 + _globals['_NODE']._serialized_start=37 + _globals['_NODE']._serialized_end=79 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py new file mode 100644 index 000000000000..f7f74d72182b --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: flwr/proto/recordset.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 
\x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.recordset_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_METRICSRECORD_DATAENTRY']._options = None + _globals['_METRICSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_CONFIGSRECORD_DATAENTRY']._options = None + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_PARAMETERSENTRY']._options = None + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_METRICSENTRY']._options = None + 
_globals['_RECORDSET_METRICSENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_CONFIGSENTRY']._options = None + _globals['_RECORDSET_CONFIGSENTRY']._serialized_options = b'8\001' + _globals['_DOUBLELIST']._serialized_start=42 + _globals['_DOUBLELIST']._serialized_end=68 + _globals['_SINT64LIST']._serialized_start=70 + _globals['_SINT64LIST']._serialized_end=96 + _globals['_BOOLLIST']._serialized_start=98 + _globals['_BOOLLIST']._serialized_end=122 + _globals['_STRINGLIST']._serialized_start=124 + _globals['_STRINGLIST']._serialized_end=150 + _globals['_BYTESLIST']._serialized_start=152 + _globals['_BYTESLIST']._serialized_end=177 + _globals['_ARRAY']._serialized_start=179 + _globals['_ARRAY']._serialized_end=245 + _globals['_METRICSRECORDVALUE']._serialized_start=248 + _globals['_METRICSRECORDVALUE']._serialized_end=407 + _globals['_CONFIGSRECORDVALUE']._serialized_start=410 + _globals['_CONFIGSRECORDVALUE']._serialized_end=755 + _globals['_PARAMETERSRECORD']._serialized_start=757 + _globals['_PARAMETERSRECORD']._serialized_end=834 + _globals['_METRICSRECORD']._serialized_start=837 + _globals['_METRICSRECORD']._serialized_end=980 + _globals['_METRICSRECORD_DATAENTRY']._serialized_start=905 + _globals['_METRICSRECORD_DATAENTRY']._serialized_end=980 + _globals['_CONFIGSRECORD']._serialized_start=983 + _globals['_CONFIGSRECORD']._serialized_end=1126 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126 + _globals['_RECORDSET']._serialized_start=1129 + _globals['_RECORDSET']._serialized_end=1536 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1307 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1386 + _globals['_RECORDSET_METRICSENTRY']._serialized_start=1388 + _globals['_RECORDSET_METRICSENTRY']._serialized_end=1461 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1463 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1536 +# 
@@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi new file mode 100644 index 000000000000..86244697129c --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -0,0 +1,305 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class DoubleList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList + +class Sint64List(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___Sint64List = Sint64List + +class BoolList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... 
+global___BoolList = BoolList + +class StringList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[typing.Text]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___StringList = StringList + +class BytesList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BytesList = BytesList + +class Array(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DTYPE_FIELD_NUMBER: builtins.int + SHAPE_FIELD_NUMBER: builtins.int + STYPE_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + dtype: typing.Text + @property + def shape(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + stype: typing.Text + data: builtins.bytes + def __init__(self, + *, + dtype: typing.Text = ..., + shape: typing.Optional[typing.Iterable[builtins.int]] = ..., + stype: typing.Text = ..., + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data","dtype",b"dtype","shape",b"shape","stype",b"stype"]) -> None: ... 
+global___Array = Array + +class MetricsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","double_list","sint64_list"]]: ... 
+global___MetricsRecordValue = MetricsRecordValue + +class ConfigsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + BOOL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + BYTES_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + BOOL_LIST_FIELD_NUMBER: builtins.int + STRING_LIST_FIELD_NUMBER: builtins.int + BYTES_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + bool: builtins.bool + string: typing.Text + bytes: builtins.bytes + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + @property + def bool_list(self) -> global___BoolList: ... + @property + def string_list(self) -> global___StringList: ... + @property + def bytes_list(self) -> global___BytesList: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + bool: builtins.bool = ..., + string: typing.Text = ..., + bytes: builtins.bytes = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + bool_list: typing.Optional[global___BoolList] = ..., + string_list: typing.Optional[global___StringList] = ..., + bytes_list: typing.Optional[global___BytesList] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... +global___ConfigsRecordValue = ConfigsRecordValue + +class ParametersRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DATA_KEYS_FIELD_NUMBER: builtins.int + DATA_VALUES_FIELD_NUMBER: builtins.int + @property + def data_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + @property + def data_values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Array]: ... + def __init__(self, + *, + data_keys: typing.Optional[typing.Iterable[typing.Text]] = ..., + data_values: typing.Optional[typing.Iterable[global___Array]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data_keys",b"data_keys","data_values",b"data_values"]) -> None: ... +global___ParametersRecord = ParametersRecord + +class MetricsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___MetricsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___MetricsRecordValue] = ..., + ) -> None: ... 
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... +global___MetricsRecord = MetricsRecord + +class ConfigsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... 
+global___ConfigsRecord = ConfigsRecord + +class RecordSet(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class ParametersEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ParametersRecord: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ParametersRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + class MetricsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___MetricsRecord: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___MetricsRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + class ConfigsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ConfigsRecord: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ConfigsRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... 
+ + PARAMETERS_FIELD_NUMBER: builtins.int + METRICS_FIELD_NUMBER: builtins.int + CONFIGS_FIELD_NUMBER: builtins.int + @property + def parameters(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ParametersRecord]: ... + @property + def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecord]: ... + @property + def configs(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecord]: ... + def __init__(self, + *, + parameters: typing.Optional[typing.Mapping[typing.Text, global___ParametersRecord]] = ..., + metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecord]] = ..., + configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecord]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["configs",b"configs","metrics",b"metrics","parameters",b"parameters"]) -> None: ... +global___RecordSet = RecordSet diff --git a/src/py/flwr/proto/recordset_pb2_grpc.py b/src/py/flwr/proto/recordset_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/recordset_pb2_grpc.pyi b/src/py/flwr/proto/recordset_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 6d8cf8fd3656..4d5f863e88dd 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -1,148 +1,34 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: flwr/proto/task.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 +from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 +from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 
\x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') - - - -_TASK = DESCRIPTOR.message_types_by_name['Task'] -_TASKINS = DESCRIPTOR.message_types_by_name['TaskIns'] -_TASKRES = DESCRIPTOR.message_types_by_name['TaskRes'] -_VALUE = DESCRIPTOR.message_types_by_name['Value'] -_VALUE_DOUBLELIST = _VALUE.nested_types_by_name['DoubleList'] -_VALUE_SINT64LIST = _VALUE.nested_types_by_name['Sint64List'] -_VALUE_BOOLLIST = _VALUE.nested_types_by_name['BoolList'] -_VALUE_STRINGLIST = _VALUE.nested_types_by_name['StringList'] -_VALUE_BYTESLIST = _VALUE.nested_types_by_name['BytesList'] -_SECUREAGGREGATION = DESCRIPTOR.message_types_by_name['SecureAggregation'] -_SECUREAGGREGATION_NAMEDVALUESENTRY = _SECUREAGGREGATION.nested_types_by_name['NamedValuesEntry'] -Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { - 'DESCRIPTOR' : _TASK, - 
'__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Task) - }) -_sym_db.RegisterMessage(Task) - -TaskIns = _reflection.GeneratedProtocolMessageType('TaskIns', (_message.Message,), { - 'DESCRIPTOR' : _TASKINS, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.TaskIns) - }) -_sym_db.RegisterMessage(TaskIns) - -TaskRes = _reflection.GeneratedProtocolMessageType('TaskRes', (_message.Message,), { - 'DESCRIPTOR' : _TASKRES, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.TaskRes) - }) -_sym_db.RegisterMessage(TaskRes) - -Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { - - 'DoubleList' : _reflection.GeneratedProtocolMessageType('DoubleList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_DOUBLELIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.DoubleList) - }) - , - - 'Sint64List' : _reflection.GeneratedProtocolMessageType('Sint64List', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_SINT64LIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.Sint64List) - }) - , - - 'BoolList' : _reflection.GeneratedProtocolMessageType('BoolList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_BOOLLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.BoolList) - }) - , - - 'StringList' : _reflection.GeneratedProtocolMessageType('StringList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_STRINGLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.StringList) - }) - , - - 'BytesList' : _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_BYTESLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.BytesList) - }) - , - 'DESCRIPTOR' : _VALUE, - 
'__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value) - }) -_sym_db.RegisterMessage(Value) -_sym_db.RegisterMessage(Value.DoubleList) -_sym_db.RegisterMessage(Value.Sint64List) -_sym_db.RegisterMessage(Value.BoolList) -_sym_db.RegisterMessage(Value.StringList) -_sym_db.RegisterMessage(Value.BytesList) - -SecureAggregation = _reflection.GeneratedProtocolMessageType('SecureAggregation', (_message.Message,), { - - 'NamedValuesEntry' : _reflection.GeneratedProtocolMessageType('NamedValuesEntry', (_message.Message,), { - 'DESCRIPTOR' : _SECUREAGGREGATION_NAMEDVALUESENTRY, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.SecureAggregation.NamedValuesEntry) - }) - , - 'DESCRIPTOR' : _SECUREAGGREGATION, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.SecureAggregation) - }) -_sym_db.RegisterMessage(SecureAggregation) -_sym_db.RegisterMessage(SecureAggregation.NamedValuesEntry) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 
\x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _TASK.fields_by_name['legacy_server_message']._options = None - _TASK.fields_by_name['legacy_server_message']._serialized_options = b'\030\001' - _TASK.fields_by_name['legacy_client_message']._options = None - _TASK.fields_by_name['legacy_client_message']._serialized_options = b'\030\001' - _SECUREAGGREGATION_NAMEDVALUESENTRY._options = None - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_options = b'8\001' - _TASK._serialized_start=89 - _TASK._serialized_end=407 - _TASKINS._serialized_start=409 - _TASKINS._serialized_end=506 - _TASKRES._serialized_start=508 - _TASKRES._serialized_end=605 - _VALUE._serialized_start=608 - _VALUE._serialized_end=1107 - _VALUE_DOUBLELIST._serialized_start=963 - _VALUE_DOUBLELIST._serialized_end=989 - _VALUE_SINT64LIST._serialized_start=991 - _VALUE_SINT64LIST._serialized_end=1017 - _VALUE_BOOLLIST._serialized_start=1019 - _VALUE_BOOLLIST._serialized_end=1043 - _VALUE_STRINGLIST._serialized_start=1045 - _VALUE_STRINGLIST._serialized_end=1071 - _VALUE_BYTESLIST._serialized_start=1073 - _VALUE_BYTESLIST._serialized_end=1098 - _SECUREAGGREGATION._serialized_start=1110 - _SECUREAGGREGATION._serialized_end=1270 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_start=1201 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_end=1270 + _globals['_TASK']._serialized_start=141 + _globals['_TASK']._serialized_end=387 + _globals['_TASKINS']._serialized_start=389 + _globals['_TASKINS']._serialized_end=481 + _globals['_TASKRES']._serialized_start=483 + _globals['_TASKRES']._serialized_end=575 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi 
b/src/py/flwr/proto/task_pb2.pyi index 7cf96cb61edf..b9c10139cfb3 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -3,8 +3,9 @@ isort:skip_file """ import builtins +import flwr.proto.error_pb2 import flwr.proto.node_pb2 -import flwr.proto.transport_pb2 +import flwr.proto.recordset_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -21,9 +22,9 @@ class Task(google.protobuf.message.Message): DELIVERED_AT_FIELD_NUMBER: builtins.int TTL_FIELD_NUMBER: builtins.int ANCESTRY_FIELD_NUMBER: builtins.int - SA_FIELD_NUMBER: builtins.int - LEGACY_SERVER_MESSAGE_FIELD_NUMBER: builtins.int - LEGACY_CLIENT_MESSAGE_FIELD_NUMBER: builtins.int + TASK_TYPE_FIELD_NUMBER: builtins.int + RECORDSET_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int @property def producer(self) -> flwr.proto.node_pb2.Node: ... @property @@ -33,12 +34,11 @@ class Task(google.protobuf.message.Message): ttl: typing.Text @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + task_type: typing.Text @property - def sa(self) -> global___SecureAggregation: ... + def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ... @property - def legacy_server_message(self) -> flwr.proto.transport_pb2.ServerMessage: ... - @property - def legacy_client_message(self) -> flwr.proto.transport_pb2.ClientMessage: ... + def error(self) -> flwr.proto.error_pb2.Error: ... 
def __init__(self, *, producer: typing.Optional[flwr.proto.node_pb2.Node] = ..., @@ -47,185 +47,54 @@ class Task(google.protobuf.message.Message): delivered_at: typing.Text = ..., ttl: typing.Text = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., - sa: typing.Optional[global___SecureAggregation] = ..., - legacy_server_message: typing.Optional[flwr.proto.transport_pb2.ServerMessage] = ..., - legacy_client_message: typing.Optional[flwr.proto.transport_pb2.ClientMessage] = ..., + task_type: typing.Text = ..., + recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., + error: typing.Optional[flwr.proto.error_pb2.Error] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa","ttl",b"ttl"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... 
global___Task = Task class TaskIns(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskIns = TaskIns class TaskRes(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskRes = TaskRes - -class Value(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class DoubleList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.float]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class Sint64List(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.int]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BoolList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class StringList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... 
- def __init__(self, - *, - vals: typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BytesList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - DOUBLE_FIELD_NUMBER: builtins.int - SINT64_FIELD_NUMBER: builtins.int - BOOL_FIELD_NUMBER: builtins.int - STRING_FIELD_NUMBER: builtins.int - BYTES_FIELD_NUMBER: builtins.int - DOUBLE_LIST_FIELD_NUMBER: builtins.int - SINT64_LIST_FIELD_NUMBER: builtins.int - BOOL_LIST_FIELD_NUMBER: builtins.int - STRING_LIST_FIELD_NUMBER: builtins.int - BYTES_LIST_FIELD_NUMBER: builtins.int - double: builtins.float - """Single element""" - - sint64: builtins.int - bool: builtins.bool - string: typing.Text - bytes: builtins.bytes - @property - def double_list(self) -> global___Value.DoubleList: - """List types""" - pass - @property - def sint64_list(self) -> global___Value.Sint64List: ... - @property - def bool_list(self) -> global___Value.BoolList: ... - @property - def string_list(self) -> global___Value.StringList: ... - @property - def bytes_list(self) -> global___Value.BytesList: ... 
- def __init__(self, - *, - double: builtins.float = ..., - sint64: builtins.int = ..., - bool: builtins.bool = ..., - string: typing.Text = ..., - bytes: builtins.bytes = ..., - double_list: typing.Optional[global___Value.DoubleList] = ..., - sint64_list: typing.Optional[global___Value.Sint64List] = ..., - bool_list: typing.Optional[global___Value.BoolList] = ..., - string_list: typing.Optional[global___Value.StringList] = ..., - bytes_list: typing.Optional[global___Value.BytesList] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... -global___Value = Value - -class SecureAggregation(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class NamedValuesEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - @property - def value(self) -> global___Value: ... - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Optional[global___Value] = ..., - ) -> None: ... 
- def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - NAMED_VALUES_FIELD_NUMBER: builtins.int - @property - def named_values(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___Value]: ... - def __init__(self, - *, - named_values: typing.Optional[typing.Mapping[typing.Text, global___Value]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["named_values",b"named_values"]) -> None: ... -global___SecureAggregation = SecureAggregation diff --git a/src/py/flwr/proto/transport_pb2.py b/src/py/flwr/proto/transport_pb2.py index 1e3785b0e312..d3aae72b63ab 100644 --- a/src/py/flwr/proto/transport_pb2.py +++ b/src/py/flwr/proto/transport_pb2.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/transport.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,281 +16,73 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/transport.proto\x12\nflwr.proto\"9\n\x06Status\x12\x1e\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x10.flwr.proto.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\nParameters\x12\x0f\n\x07tensors\x18\x01 \x03(\x0c\x12\x13\n\x0btensor_type\x18\x02 \x01(\t\"\xba\x08\n\rServerMessage\x12?\n\rreconnect_ins\x18\x01 
\x01(\x0b\x32&.flwr.proto.ServerMessage.ReconnectInsH\x00\x12H\n\x12get_properties_ins\x18\x02 \x01(\x0b\x32*.flwr.proto.ServerMessage.GetPropertiesInsH\x00\x12H\n\x12get_parameters_ins\x18\x03 \x01(\x0b\x32*.flwr.proto.ServerMessage.GetParametersInsH\x00\x12\x33\n\x07\x66it_ins\x18\x04 \x01(\x0b\x32 .flwr.proto.ServerMessage.FitInsH\x00\x12=\n\x0c\x65valuate_ins\x18\x05 \x01(\x0b\x32%.flwr.proto.ServerMessage.EvaluateInsH\x00\x1a\x1f\n\x0cReconnectIns\x12\x0f\n\x07seconds\x18\x01 \x01(\x03\x1a\x9d\x01\n\x10GetPropertiesIns\x12\x46\n\x06\x63onfig\x18\x01 \x03(\x0b\x32\x36.flwr.proto.ServerMessage.GetPropertiesIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x9d\x01\n\x10GetParametersIns\x12\x46\n\x06\x63onfig\x18\x01 \x03(\x0b\x32\x36.flwr.proto.ServerMessage.GetParametersIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xb5\x01\n\x06\x46itIns\x12*\n\nparameters\x18\x01 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12<\n\x06\x63onfig\x18\x02 \x03(\x0b\x32,.flwr.proto.ServerMessage.FitIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xbf\x01\n\x0b\x45valuateIns\x12*\n\nparameters\x18\x01 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12\x41\n\x06\x63onfig\x18\x02 \x03(\x0b\x32\x31.flwr.proto.ServerMessage.EvaluateIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x42\x05\n\x03msg\"\xa0\t\n\rClientMessage\x12\x41\n\x0e\x64isconnect_res\x18\x01 \x01(\x0b\x32\'.flwr.proto.ClientMessage.DisconnectResH\x00\x12H\n\x12get_properties_res\x18\x02 \x01(\x0b\x32*.flwr.proto.ClientMessage.GetPropertiesResH\x00\x12H\n\x12get_parameters_res\x18\x03 
\x01(\x0b\x32*.flwr.proto.ClientMessage.GetParametersResH\x00\x12\x33\n\x07\x66it_res\x18\x04 \x01(\x0b\x32 .flwr.proto.ClientMessage.FitResH\x00\x12=\n\x0c\x65valuate_res\x18\x05 \x01(\x0b\x32%.flwr.proto.ClientMessage.EvaluateResH\x00\x1a\x33\n\rDisconnectRes\x12\"\n\x06reason\x18\x01 \x01(\x0e\x32\x12.flwr.proto.Reason\x1a\xcd\x01\n\x10GetPropertiesRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12N\n\nproperties\x18\x02 \x03(\x0b\x32:.flwr.proto.ClientMessage.GetPropertiesRes.PropertiesEntry\x1a\x45\n\x0fPropertiesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x62\n\x10GetParametersRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12*\n\nparameters\x18\x02 \x01(\x0b\x32\x16.flwr.proto.Parameters\x1a\xf2\x01\n\x06\x46itRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12*\n\nparameters\x18\x02 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12\x14\n\x0cnum_examples\x18\x03 \x01(\x03\x12>\n\x07metrics\x18\x04 \x03(\x0b\x32-.flwr.proto.ClientMessage.FitRes.MetricsEntry\x1a\x42\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xde\x01\n\x0b\x45valuateRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12\x0c\n\x04loss\x18\x02 \x01(\x02\x12\x14\n\x0cnum_examples\x18\x03 \x01(\x03\x12\x43\n\x07metrics\x18\x04 \x03(\x0b\x32\x32.flwr.proto.ClientMessage.EvaluateRes.MetricsEntry\x1a\x42\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x42\x05\n\x03msg\"i\n\x06Scalar\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x08 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\r \x01(\x08H\x00\x12\x10\n\x06string\x18\x0e \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x0f 
\x01(\x0cH\x00\x42\x08\n\x06scalar*\x8d\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\"\n\x1eGET_PROPERTIES_NOT_IMPLEMENTED\x10\x01\x12\"\n\x1eGET_PARAMETERS_NOT_IMPLEMENTED\x10\x02\x12\x17\n\x13\x46IT_NOT_IMPLEMENTED\x10\x03\x12\x1c\n\x18\x45VALUATE_NOT_IMPLEMENTED\x10\x04*[\n\x06Reason\x12\x0b\n\x07UNKNOWN\x10\x00\x12\r\n\tRECONNECT\x10\x01\x12\x16\n\x12POWER_DISCONNECTED\x10\x02\x12\x14\n\x10WIFI_UNAVAILABLE\x10\x03\x12\x07\n\x03\x41\x43K\x10\x04\x32S\n\rFlowerService\x12\x42\n\x04Join\x12\x19.flwr.proto.ClientMessage\x1a\x19.flwr.proto.ServerMessage\"\x00(\x01\x30\x01\x62\x06proto3') -_CODE = DESCRIPTOR.enum_types_by_name['Code'] -Code = enum_type_wrapper.EnumTypeWrapper(_CODE) -_REASON = DESCRIPTOR.enum_types_by_name['Reason'] -Reason = enum_type_wrapper.EnumTypeWrapper(_REASON) -OK = 0 -GET_PROPERTIES_NOT_IMPLEMENTED = 1 -GET_PARAMETERS_NOT_IMPLEMENTED = 2 -FIT_NOT_IMPLEMENTED = 3 -EVALUATE_NOT_IMPLEMENTED = 4 -UNKNOWN = 0 -RECONNECT = 1 -POWER_DISCONNECTED = 2 -WIFI_UNAVAILABLE = 3 -ACK = 4 - - -_STATUS = DESCRIPTOR.message_types_by_name['Status'] -_PARAMETERS = DESCRIPTOR.message_types_by_name['Parameters'] -_SERVERMESSAGE = DESCRIPTOR.message_types_by_name['ServerMessage'] -_SERVERMESSAGE_RECONNECTINS = _SERVERMESSAGE.nested_types_by_name['ReconnectIns'] -_SERVERMESSAGE_GETPROPERTIESINS = _SERVERMESSAGE.nested_types_by_name['GetPropertiesIns'] -_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY = _SERVERMESSAGE_GETPROPERTIESINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_GETPARAMETERSINS = _SERVERMESSAGE.nested_types_by_name['GetParametersIns'] -_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY = _SERVERMESSAGE_GETPARAMETERSINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_FITINS = _SERVERMESSAGE.nested_types_by_name['FitIns'] -_SERVERMESSAGE_FITINS_CONFIGENTRY = _SERVERMESSAGE_FITINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_EVALUATEINS = _SERVERMESSAGE.nested_types_by_name['EvaluateIns'] -_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY = 
_SERVERMESSAGE_EVALUATEINS.nested_types_by_name['ConfigEntry'] -_CLIENTMESSAGE = DESCRIPTOR.message_types_by_name['ClientMessage'] -_CLIENTMESSAGE_DISCONNECTRES = _CLIENTMESSAGE.nested_types_by_name['DisconnectRes'] -_CLIENTMESSAGE_GETPROPERTIESRES = _CLIENTMESSAGE.nested_types_by_name['GetPropertiesRes'] -_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY = _CLIENTMESSAGE_GETPROPERTIESRES.nested_types_by_name['PropertiesEntry'] -_CLIENTMESSAGE_GETPARAMETERSRES = _CLIENTMESSAGE.nested_types_by_name['GetParametersRes'] -_CLIENTMESSAGE_FITRES = _CLIENTMESSAGE.nested_types_by_name['FitRes'] -_CLIENTMESSAGE_FITRES_METRICSENTRY = _CLIENTMESSAGE_FITRES.nested_types_by_name['MetricsEntry'] -_CLIENTMESSAGE_EVALUATERES = _CLIENTMESSAGE.nested_types_by_name['EvaluateRes'] -_CLIENTMESSAGE_EVALUATERES_METRICSENTRY = _CLIENTMESSAGE_EVALUATERES.nested_types_by_name['MetricsEntry'] -_SCALAR = DESCRIPTOR.message_types_by_name['Scalar'] -Status = _reflection.GeneratedProtocolMessageType('Status', (_message.Message,), { - 'DESCRIPTOR' : _STATUS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Status) - }) -_sym_db.RegisterMessage(Status) - -Parameters = _reflection.GeneratedProtocolMessageType('Parameters', (_message.Message,), { - 'DESCRIPTOR' : _PARAMETERS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Parameters) - }) -_sym_db.RegisterMessage(Parameters) - -ServerMessage = _reflection.GeneratedProtocolMessageType('ServerMessage', (_message.Message,), { - - 'ReconnectIns' : _reflection.GeneratedProtocolMessageType('ReconnectIns', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_RECONNECTINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.ReconnectIns) - }) - , - - 'GetPropertiesIns' : _reflection.GeneratedProtocolMessageType('GetPropertiesIns', (_message.Message,), { - - 'ConfigEntry' : 
_reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetPropertiesIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_GETPROPERTIESINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetPropertiesIns) - }) - , - - 'GetParametersIns' : _reflection.GeneratedProtocolMessageType('GetParametersIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetParametersIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_GETPARAMETERSINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetParametersIns) - }) - , - - 'FitIns' : _reflection.GeneratedProtocolMessageType('FitIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_FITINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.FitIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_FITINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.FitIns) - }) - , - - 'EvaluateIns' : _reflection.GeneratedProtocolMessageType('EvaluateIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # 
@@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.EvaluateIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_EVALUATEINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.EvaluateIns) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage) - }) -_sym_db.RegisterMessage(ServerMessage) -_sym_db.RegisterMessage(ServerMessage.ReconnectIns) -_sym_db.RegisterMessage(ServerMessage.GetPropertiesIns) -_sym_db.RegisterMessage(ServerMessage.GetPropertiesIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.GetParametersIns) -_sym_db.RegisterMessage(ServerMessage.GetParametersIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.FitIns) -_sym_db.RegisterMessage(ServerMessage.FitIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.EvaluateIns) -_sym_db.RegisterMessage(ServerMessage.EvaluateIns.ConfigEntry) - -ClientMessage = _reflection.GeneratedProtocolMessageType('ClientMessage', (_message.Message,), { - - 'DisconnectRes' : _reflection.GeneratedProtocolMessageType('DisconnectRes', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_DISCONNECTRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.DisconnectRes) - }) - , - - 'GetPropertiesRes' : _reflection.GeneratedProtocolMessageType('GetPropertiesRes', (_message.Message,), { - - 'PropertiesEntry' : _reflection.GeneratedProtocolMessageType('PropertiesEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetPropertiesRes.PropertiesEntry) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPROPERTIESRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetPropertiesRes) - }) - , - - 
'GetParametersRes' : _reflection.GeneratedProtocolMessageType('GetParametersRes', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPARAMETERSRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetParametersRes) - }) - , - - 'FitRes' : _reflection.GeneratedProtocolMessageType('FitRes', (_message.Message,), { - - 'MetricsEntry' : _reflection.GeneratedProtocolMessageType('MetricsEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_FITRES_METRICSENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.FitRes.MetricsEntry) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE_FITRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.FitRes) - }) - , - - 'EvaluateRes' : _reflection.GeneratedProtocolMessageType('EvaluateRes', (_message.Message,), { - - 'MetricsEntry' : _reflection.GeneratedProtocolMessageType('MetricsEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_EVALUATERES_METRICSENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.EvaluateRes.MetricsEntry) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE_EVALUATERES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.EvaluateRes) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage) - }) -_sym_db.RegisterMessage(ClientMessage) -_sym_db.RegisterMessage(ClientMessage.DisconnectRes) -_sym_db.RegisterMessage(ClientMessage.GetPropertiesRes) -_sym_db.RegisterMessage(ClientMessage.GetPropertiesRes.PropertiesEntry) -_sym_db.RegisterMessage(ClientMessage.GetParametersRes) -_sym_db.RegisterMessage(ClientMessage.FitRes) -_sym_db.RegisterMessage(ClientMessage.FitRes.MetricsEntry) 
-_sym_db.RegisterMessage(ClientMessage.EvaluateRes) -_sym_db.RegisterMessage(ClientMessage.EvaluateRes.MetricsEntry) - -Scalar = _reflection.GeneratedProtocolMessageType('Scalar', (_message.Message,), { - 'DESCRIPTOR' : _SCALAR, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Scalar) - }) -_sym_db.RegisterMessage(Scalar) - -_FLOWERSERVICE = DESCRIPTOR.services_by_name['FlowerService'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.transport_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._options = None - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._options = None - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_FITINS_CONFIGENTRY._options = None - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._options = None - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._options = None - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_FITRES_METRICSENTRY._options = None - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._options = None - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_options = b'8\001' - _CODE._serialized_start=2533 - _CODE._serialized_end=2674 - _REASON._serialized_start=2676 - _REASON._serialized_end=2767 - _STATUS._serialized_start=42 - _STATUS._serialized_end=99 - _PARAMETERS._serialized_start=101 - _PARAMETERS._serialized_end=151 - _SERVERMESSAGE._serialized_start=154 - _SERVERMESSAGE._serialized_end=1236 - _SERVERMESSAGE_RECONNECTINS._serialized_start=500 - 
_SERVERMESSAGE_RECONNECTINS._serialized_end=531 - _SERVERMESSAGE_GETPROPERTIESINS._serialized_start=534 - _SERVERMESSAGE_GETPROPERTIESINS._serialized_end=691 - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_GETPARAMETERSINS._serialized_start=694 - _SERVERMESSAGE_GETPARAMETERSINS._serialized_end=851 - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_FITINS._serialized_start=854 - _SERVERMESSAGE_FITINS._serialized_end=1035 - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_EVALUATEINS._serialized_start=1038 - _SERVERMESSAGE_EVALUATEINS._serialized_end=1229 - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_end=691 - _CLIENTMESSAGE._serialized_start=1239 - _CLIENTMESSAGE._serialized_end=2423 - _CLIENTMESSAGE_DISCONNECTRES._serialized_start=1587 - _CLIENTMESSAGE_DISCONNECTRES._serialized_end=1638 - _CLIENTMESSAGE_GETPROPERTIESRES._serialized_start=1641 - _CLIENTMESSAGE_GETPROPERTIESRES._serialized_end=1846 - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_start=1777 - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_end=1846 - _CLIENTMESSAGE_GETPARAMETERSRES._serialized_start=1848 - _CLIENTMESSAGE_GETPARAMETERSRES._serialized_end=1946 - _CLIENTMESSAGE_FITRES._serialized_start=1949 - _CLIENTMESSAGE_FITRES._serialized_end=2191 - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_start=2125 - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_end=2191 - _CLIENTMESSAGE_EVALUATERES._serialized_start=2194 - _CLIENTMESSAGE_EVALUATERES._serialized_end=2416 - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_start=2125 - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_end=2191 - _SCALAR._serialized_start=2425 - 
_SCALAR._serialized_end=2530 - _FLOWERSERVICE._serialized_start=2769 - _FLOWERSERVICE._serialized_end=2852 + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._options = None + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._options = None + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._options = None + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_options = b'8\001' + _globals['_CODE']._serialized_start=2533 + _globals['_CODE']._serialized_end=2674 + _globals['_REASON']._serialized_start=2676 + _globals['_REASON']._serialized_end=2767 + _globals['_STATUS']._serialized_start=42 + _globals['_STATUS']._serialized_end=99 + _globals['_PARAMETERS']._serialized_start=101 + _globals['_PARAMETERS']._serialized_end=151 + _globals['_SERVERMESSAGE']._serialized_start=154 + _globals['_SERVERMESSAGE']._serialized_end=1236 + _globals['_SERVERMESSAGE_RECONNECTINS']._serialized_start=500 + _globals['_SERVERMESSAGE_RECONNECTINS']._serialized_end=531 + _globals['_SERVERMESSAGE_GETPROPERTIESINS']._serialized_start=534 + _globals['_SERVERMESSAGE_GETPROPERTIESINS']._serialized_end=691 + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_start=626 + 
_globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_GETPARAMETERSINS']._serialized_start=694 + _globals['_SERVERMESSAGE_GETPARAMETERSINS']._serialized_end=851 + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_FITINS']._serialized_start=854 + _globals['_SERVERMESSAGE_FITINS']._serialized_end=1035 + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_EVALUATEINS']._serialized_start=1038 + _globals['_SERVERMESSAGE_EVALUATEINS']._serialized_end=1229 + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_end=691 + _globals['_CLIENTMESSAGE']._serialized_start=1239 + _globals['_CLIENTMESSAGE']._serialized_end=2423 + _globals['_CLIENTMESSAGE_DISCONNECTRES']._serialized_start=1587 + _globals['_CLIENTMESSAGE_DISCONNECTRES']._serialized_end=1638 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES']._serialized_start=1641 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES']._serialized_end=1846 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_start=1777 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_end=1846 + _globals['_CLIENTMESSAGE_GETPARAMETERSRES']._serialized_start=1848 + _globals['_CLIENTMESSAGE_GETPARAMETERSRES']._serialized_end=1946 + _globals['_CLIENTMESSAGE_FITRES']._serialized_start=1949 + _globals['_CLIENTMESSAGE_FITRES']._serialized_end=2191 + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_start=2125 + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_end=2191 + _globals['_CLIENTMESSAGE_EVALUATERES']._serialized_start=2194 + _globals['_CLIENTMESSAGE_EVALUATERES']._serialized_end=2416 + 
_globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_start=2125 + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_end=2191 + _globals['_SCALAR']._serialized_start=2425 + _globals['_SCALAR']._serialized_end=2530 + _globals['_FLOWERSERVICE']._serialized_start=2769 + _globals['_FLOWERSERVICE']._serialized_end=2852 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 74abe8dd463c..633bd668b520 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -16,25 +16,37 @@ from . import strategy -from .app import ServerConfig as ServerConfig +from . import workflow as workflow from .app import run_driver_api as run_driver_api from .app import run_fleet_api as run_fleet_api -from .app import run_server as run_server +from .app import run_superlink as run_superlink from .app import start_server as start_server from .client_manager import ClientManager as ClientManager from .client_manager import SimpleClientManager as SimpleClientManager +from .compat import LegacyContext as LegacyContext +from .compat import start_driver as start_driver +from .driver import Driver as Driver from .history import History as History +from .run_serverapp import run_server_app as run_server_app from .server import Server as Server +from .server_app import ServerApp as ServerApp +from .server_config import ServerConfig as ServerConfig __all__ = [ "ClientManager", + "Driver", "History", + "LegacyContext", "run_driver_api", "run_fleet_api", - "run_server", + "run_server_app", + "run_superlink", "Server", + "ServerApp", "ServerConfig", "SimpleClientManager", + "start_driver", "start_server", "strategy", + "workflow", ] diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 63c24c37a685..e04cfb37e118 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -14,17 +14,14 @@ # 
============================================================================== """Flower server app.""" - import argparse +import asyncio import importlib.util import sys import threading -from dataclasses import dataclass from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path -from signal import SIGINT, SIGTERM, signal -from types import FrameType from typing import List, Optional, Tuple import grpc @@ -33,27 +30,29 @@ from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, - TRANSPORT_TYPE_GRPC_BIDI, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_VCE, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log -from flwr.proto.driver_pb2_grpc import add_DriverServicer_to_server -from flwr.proto.fleet_pb2_grpc import add_FleetServicer_to_server -from flwr.proto.transport_pb2_grpc import add_FlowerServiceServicer_to_server -from flwr.server.client_manager import ClientManager, SimpleClientManager -from flwr.server.driver.driver_servicer import DriverServicer -from flwr.server.fleet.grpc_bidi.driver_client_manager import DriverClientManager -from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer -from flwr.server.fleet.grpc_bidi.grpc_server import ( +from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 + add_FleetServicer_to_server, +) + +from .client_manager import ClientManager +from .history import History +from .server import Server, init_defaults, run_fl +from .server_config import ServerConfig +from .strategy import Strategy +from .superlink.driver.driver_grpc import run_driver_api_grpc +from .superlink.fleet.grpc_bidi.grpc_server import ( generic_create_grpc_server, start_grpc_server, ) -from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer -from flwr.server.history import History -from flwr.server.server import Server -from flwr.server.state import StateFactory -from 
flwr.server.strategy import FedAvg, Strategy +from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer +from .superlink.fleet.vce import start_vce +from .superlink.state import StateFactory ADDRESS_DRIVER_API = "0.0.0.0:9091" ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092" @@ -63,18 +62,6 @@ DATABASE = ":flwr-in-memory-state:" -@dataclass -class ServerConfig: - """Flower server config. - - All attributes have default values which allows users to configure just the ones - they care about. - """ - - num_rounds: int = 1 - round_timeout: Optional[float] = None - - def start_server( # pylint: disable=too-many-arguments,too-many-locals *, server_address: str = ADDRESS_FLEET_API_GRPC_BIDI, @@ -194,52 +181,11 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals return hist -def init_defaults( - server: Optional[Server], - config: Optional[ServerConfig], - strategy: Optional[Strategy], - client_manager: Optional[ClientManager], -) -> Tuple[Server, ServerConfig]: - """Create server instance if none was given.""" - if server is None: - if client_manager is None: - client_manager = SimpleClientManager() - if strategy is None: - strategy = FedAvg() - server = Server(client_manager=client_manager, strategy=strategy) - elif strategy is not None: - log(WARN, "Both server and strategy were provided, ignoring strategy") - - # Set default config values - if config is None: - config = ServerConfig() - - return server, config - - -def run_fl( - server: Server, - config: ServerConfig, -) -> History: - """Train a model on the given server and return the History object.""" - hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout) - log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed)) - log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit)) - log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) - log(INFO, "app_fit: losses_centralized %s", 
str(hist.losses_centralized)) - log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) - - # Graceful shutdown - server.disconnect_all_clients(timeout=config.round_timeout) - - return hist - - def run_driver_api() -> None: """Run Flower server (Driver API).""" log(INFO, "Starting Flower server (Driver API)") event(EventType.RUN_DRIVER_API_ENTER) - args = _parse_args_driver().parse_args() + args = _parse_args_run_driver_api().parse_args() # Parse IP address parsed_address = parse_address(args.driver_api_address) @@ -255,17 +201,17 @@ def run_driver_api() -> None: state_factory = StateFactory(args.database) # Start server - grpc_server: grpc.Server = _run_driver_api_grpc( + grpc_server: grpc.Server = run_driver_api_grpc( address=address, state_factory=state_factory, certificates=certificates, ) # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_DRIVER_API_LEAVE, grpc_servers=[grpc_server], bckg_threads=[], - event_type=EventType.RUN_DRIVER_API_LEAVE, ) # Block @@ -276,7 +222,7 @@ def run_fleet_api() -> None: """Run Flower server (Fleet API).""" log(INFO, "Starting Flower server (Fleet API)") event(EventType.RUN_FLEET_API_ENTER) - args = _parse_args_fleet().parse_args() + args = _parse_args_run_fleet_api().parse_args() # Obtain certificates certificates = _try_obtain_certificates(args) @@ -313,19 +259,6 @@ def run_fleet_api() -> None: ) fleet_thread.start() bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_BIDI: - address_arg = args.grpc_fleet_api_address - parsed_address = parse_address(address_arg) - if not parsed_address: - sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - fleet_server = _run_fleet_api_grpc_bidi( - address=address, - state_factory=state_factory, - certificates=certificates, - ) - grpc_servers.append(fleet_server) elif 
args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: address_arg = args.grpc_rere_fleet_api_address parsed_address = parse_address(address_arg) @@ -343,10 +276,10 @@ def run_fleet_api() -> None: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_FLEET_API_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_FLEET_API_LEAVE, ) # Block @@ -357,11 +290,11 @@ def run_fleet_api() -> None: # pylint: disable=too-many-branches, too-many-locals, too-many-statements -def run_server() -> None: +def run_superlink() -> None: """Run Flower server (Driver API and Fleet API).""" log(INFO, "Starting Flower server") - event(EventType.RUN_SERVER_ENTER) - args = _parse_args_server().parse_args() + event(EventType.RUN_SUPERLINK_ENTER) + args = _parse_args_run_superlink().parse_args() # Parse IP address parsed_address = parse_address(args.driver_api_address) @@ -377,7 +310,7 @@ def run_server() -> None: state_factory = StateFactory(args.database) # Start Driver API - driver_server: grpc.Server = _run_driver_api_grpc( + driver_server: grpc.Server = run_driver_api_grpc( address=address, state_factory=state_factory, certificates=certificates, @@ -412,19 +345,6 @@ def run_server() -> None: ) fleet_thread.start() bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_BIDI: - address_arg = args.grpc_bidi_fleet_api_address - parsed_address = parse_address(address_arg) - if not parsed_address: - sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - fleet_server = _run_fleet_api_grpc_bidi( - address=address, - state_factory=state_factory, - certificates=certificates, - ) - grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: address_arg = args.grpc_rere_fleet_api_address 
parsed_address = parse_address(address_arg) @@ -438,14 +358,25 @@ def run_server() -> None: certificates=certificates, ) grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_VCE: + f_stop = asyncio.Event() # Does nothing + _run_fleet_api_vce( + num_supernodes=args.num_supernodes, + client_app_attr=args.client_app, + backend_name=args.backend, + backend_config_json_stream=args.backend_config, + app_dir=args.app_dir, + state_factory=state_factory, + f_stop=f_stop, + ) else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_SUPERLINK_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_SERVER_LEAVE, ) # Block @@ -480,105 +411,6 @@ def _try_obtain_certificates( return certificates -def _register_exit_handlers( - grpc_servers: List[grpc.Server], - bckg_threads: List[threading.Thread], - event_type: EventType, -) -> None: - default_handlers = { - SIGINT: None, - SIGTERM: None, - } - - def graceful_exit_handler( # type: ignore - signalnum, - frame: FrameType, # pylint: disable=unused-argument - ) -> None: - """Exit handler to be registered with signal.signal. - - When called will reset signal handler to original signal handler from - default_handlers. 
- """ - # Reset to default handler - signal(signalnum, default_handlers[signalnum]) - - event_res = event(event_type=event_type) - - for grpc_server in grpc_servers: - grpc_server.stop(grace=1) - - for bckg_thread in bckg_threads: - bckg_thread.join() - - # Ensure event has happend - event_res.result() - - # Setup things for graceful exit - sys.exit(0) - - default_handlers[SIGINT] = signal( # type: ignore - SIGINT, - graceful_exit_handler, # type: ignore - ) - default_handlers[SIGTERM] = signal( # type: ignore - SIGTERM, - graceful_exit_handler, # type: ignore - ) - - -def _run_driver_api_grpc( - address: str, - state_factory: StateFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> grpc.Server: - """Run Driver API (gRPC, request-response).""" - # Create Driver API gRPC server - driver_servicer: grpc.Server = DriverServicer( - state_factory=state_factory, - ) - driver_add_servicer_to_server_fn = add_DriverServicer_to_server - driver_grpc_server = generic_create_grpc_server( - servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn), - server_address=address, - max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=certificates, - ) - - log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address) - driver_grpc_server.start() - - return driver_grpc_server - - -def _run_fleet_api_grpc_bidi( - address: str, - state_factory: StateFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> grpc.Server: - """Run Fleet API (gRPC, bidirectional streaming).""" - # DriverClientManager - driver_client_manager = DriverClientManager( - state_factory=state_factory, - ) - - # Create (legacy) Fleet API gRPC server - fleet_servicer = FlowerServiceServicer( - client_manager=driver_client_manager, - ) - fleet_add_servicer_to_server_fn = add_FlowerServiceServicer_to_server - fleet_grpc_server = generic_create_grpc_server( - servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn), - server_address=address, - 
max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=certificates, - ) - - log(INFO, "Flower ECE: Starting Fleet API (gRPC-bidi) on %s", address) - fleet_grpc_server.start() - - return fleet_grpc_server - - def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, @@ -587,7 +419,7 @@ def _run_fleet_api_grpc_rere( """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server fleet_servicer = FleetServicer( - state=state_factory.state(), + state_factory=state_factory, ) fleet_add_servicer_to_server_fn = add_FleetServicer_to_server fleet_grpc_server = generic_create_grpc_server( @@ -603,6 +435,29 @@ def _run_fleet_api_grpc_rere( return fleet_grpc_server +# pylint: disable=too-many-arguments +def _run_fleet_api_vce( + num_supernodes: int, + client_app_attr: str, + backend_name: str, + backend_config_json_stream: str, + app_dir: str, + state_factory: StateFactory, + f_stop: asyncio.Event, +) -> None: + log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)") + + start_vce( + num_supernodes=num_supernodes, + client_app_attr=client_app_attr, + backend_name=backend_name, + backend_config_json_stream=backend_config_json_stream, + state_factory=state_factory, + app_dir=app_dir, + f_stop=f_stop, + ) + + # pylint: disable=import-outside-toplevel,too-many-arguments def _run_fleet_api_rest( host: str, @@ -616,7 +471,7 @@ def _run_fleet_api_rest( try: import uvicorn - from flwr.server.fleet.rest_rere.rest_api import app as fast_api_app + from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app except ModuleNotFoundError: sys.exit(MISSING_EXTRA_REST) if workers != 1: @@ -639,7 +494,7 @@ def _run_fleet_api_rest( raise ValueError(validation_exceptions) uvicorn.run( - app="flwr.server.fleet.rest_rere.rest_api:app", + app="flwr.server.superlink.fleet.rest_rere.rest_api:app", port=port, host=host, reload=False, @@ -676,7 +531,7 @@ def _validate_ssl_files( return validation_exceptions -def _parse_args_driver() -> 
argparse.ArgumentParser: +def _parse_args_run_driver_api() -> argparse.ArgumentParser: """Parse command line arguments for Driver API.""" parser = argparse.ArgumentParser( description="Start a Flower Driver API server. " @@ -693,7 +548,7 @@ def _parse_args_driver() -> argparse.ArgumentParser: return parser -def _parse_args_fleet() -> argparse.ArgumentParser: +def _parse_args_run_fleet_api() -> argparse.ArgumentParser: """Parse command line arguments for Fleet API.""" parser = argparse.ArgumentParser( description="Start a Flower Fleet API server." @@ -710,7 +565,7 @@ def _parse_args_fleet() -> argparse.ArgumentParser: return parser -def _parse_args_server() -> argparse.ArgumentParser: +def _parse_args_run_superlink() -> argparse.ArgumentParser: """Parse command line arguments for both Driver API and Fleet API.""" parser = argparse.ArgumentParser( description="This will start a Flower server " @@ -779,12 +634,13 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: const=TRANSPORT_TYPE_REST, help="Start a Fleet API server (REST, experimental)", ) + ex_group.add_argument( - "--grpc-bidi", + "--vce", action="store_const", dest="fleet_api_type", - const=TRANSPORT_TYPE_GRPC_BIDI, - help="Start a Fleet API server (gRPC-bidi)", + const=TRANSPORT_TYPE_VCE, + help="Start a Fleet API server (VirtualClientEngine)", ) # Fleet API gRPC-rere options @@ -823,12 +679,35 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: default=1, ) - # Fleet API gRPC-bidi options - grpc_bidi_group = parser.add_argument_group( - "Fleet API (gRPC-bidi) server options", "" + # Fleet API VCE options + vce_group = parser.add_argument_group("Fleet API (VCE) server options", "") + vce_group.add_argument( + "--client-app", + help="For example: `client:app` or `project.package.module:wrapper.app`.", ) - grpc_bidi_group.add_argument( - "--grpc-bidi-fleet-api-address", - help="Fleet API (gRPC-bidi) server address (IPv4, IPv6, or a domain name)", - 
default=ADDRESS_FLEET_API_GRPC_RERE, + vce_group.add_argument( + "--num-supernodes", + type=int, + help="Number of simulated SuperNodes.", + ) + vce_group.add_argument( + "--backend", + default="ray", + type=str, + help="Simulation backend that executes the ClientApp.", + ) + vce_group.add_argument( + "--backend-config", + type=str, + default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}', + help='A JSON formatted stream, e.g \'{"":, "":}\' to ' + "configure a backend. Values supported in are those included by " + "`flwr.common.typing.ConfigsRecordValues`. ", + ) + parser.add_argument( + "--app-dir", + default="", + help="Add specified directory to the PYTHONPATH and load" + "ClientApp from there." + " Default: current working directory.", ) diff --git a/src/py/flwr/server/client_manager_test.py b/src/py/flwr/server/client_manager_test.py index 8145b9b2ab7f..5820881b6aad 100644 --- a/src/py/flwr/server/client_manager_test.py +++ b/src/py/flwr/server/client_manager_test.py @@ -18,7 +18,7 @@ from unittest.mock import MagicMock from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def test_simple_client_manager_register() -> None: diff --git a/src/py/flwr/server/client_proxy.py b/src/py/flwr/server/client_proxy.py index 7d0547be304b..951e4ae992da 100644 --- a/src/py/flwr/server/client_proxy.py +++ b/src/py/flwr/server/client_proxy.py @@ -47,6 +47,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" @@ -55,6 +56,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" @@ -63,6 +65,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: 
Optional[int], ) -> FitRes: """Refine the provided parameters using the locally held dataset.""" @@ -71,6 +74,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" @@ -79,5 +83,6 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" diff --git a/src/py/flwr/server/client_proxy_test.py b/src/py/flwr/server/client_proxy_test.py index 266e4cbeb266..685698558e3a 100644 --- a/src/py/flwr/server/client_proxy_test.py +++ b/src/py/flwr/server/client_proxy_test.py @@ -42,6 +42,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" return GetPropertiesRes(status=Status(code=Code.OK, message=""), properties={}) @@ -50,6 +51,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" return GetParametersRes( @@ -61,6 +63,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> FitRes: """Refine the provided weights using the locally held dataset.""" return FitRes( @@ -74,6 +77,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided weights using the locally held dataset.""" return EvaluateRes( @@ -84,6 +88,7 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" return DisconnectRes(reason="") diff --git a/src/py/flwr/driver/__init__.py b/src/py/flwr/server/compat/__init__.py similarity index 74% rename from src/py/flwr/driver/__init__.py rename to src/py/flwr/server/compat/__init__.py 
index 1c3b09cc334b..7bae196ddb65 100644 --- a/src/py/flwr/driver/__init__.py +++ b/src/py/flwr/server/compat/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower driver SDK.""" +"""Flower ServerApp compatibility package.""" -from .app import start_driver -from .driver import Driver -from .grpc_driver import GrpcDriver +from .app import start_driver as start_driver +from .legacy_context import LegacyContext as LegacyContext __all__ = [ - "Driver", - "GrpcDriver", + "LegacyContext", "start_driver", ] diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/server/compat/app.py similarity index 51% rename from src/py/flwr/driver/app.py rename to src/py/flwr/server/compat/app.py index 3cb8652365d8..ff1d99b5366e 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/server/compat/app.py @@ -16,24 +16,21 @@ import sys -import threading -import time from logging import INFO from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union from flwr.common import EventType, event from flwr.common.address import parse_address -from flwr.common.logger import log -from flwr.proto import driver_pb2 -from flwr.server.app import ServerConfig, init_defaults, run_fl +from flwr.common.logger import log, warn_deprecated_feature from flwr.server.client_manager import ClientManager from flwr.server.history import History -from flwr.server.server import Server +from flwr.server.server import Server, init_defaults, run_fl +from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy -from 
.driver_client_proxy import DriverClientProxy -from .grpc_driver import GrpcDriver +from ..driver import Driver +from .app_utils import start_update_client_manager_thread DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -53,6 +50,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, root_certificates: Optional[Union[bytes, str]] = None, + driver: Optional[Driver] = None, ) -> History: """Start a Flower Driver API server. @@ -72,15 +70,16 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals An implementation of the abstract base class `flwr.server.strategy.Strategy`. If no strategy is provided, then `start_server` will use `flwr.server.strategy.FedAvg`. - client_manager : Optional[flwr.driver.DriverClientManager] (default: None) - An implementation of the class - `flwr.driver.driver_client_manager.DriverClientManager`. If no + client_manager : Optional[flwr.server.ClientManager] (default: None) + An implementation of the class `flwr.server.ClientManager`. If no implementation is provided, then `start_driver` will use - `flwr.driver.driver_client_manager.DriverClientManager`. + `flwr.server.SimpleClientManager`. root_certificates : Optional[Union[bytes, str]] (default: None) The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. + driver : Optional[Driver] (default: None) + The Driver object to use. 
Returns ------- @@ -101,19 +100,23 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals """ event(EventType.START_DRIVER_ENTER) - # Parse IP address - parsed_address = parse_address(server_address) - if not parsed_address: - sys.exit(f"Server IP address ({server_address}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + if driver is None: + # Not passing a `Driver` object is deprecated + warn_deprecated_feature("start_driver") - # Create the Driver - if isinstance(root_certificates, str): - root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver(driver_service_address=address, certificates=root_certificates) - driver.connect() - lock = threading.Lock() + # Parse IP address + parsed_address = parse_address(server_address) + if not parsed_address: + sys.exit(f"Server IP address ({server_address}) cannot be parsed.") + host, port, is_v6 = parsed_address + address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + # Create the Driver + if isinstance(root_certificates, str): + root_certificates = Path(root_certificates).read_bytes() + driver = Driver( + driver_service_address=address, root_certificates=root_certificates + ) # Initialize the Driver API server and config initialized_server, initialized_config = init_defaults( @@ -124,20 +127,15 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals ) log( INFO, - "Starting Flower server, config: %s", + "Starting Flower ServerApp, config: %s", initialized_config, ) + log(INFO, "") # Start the thread updating nodes - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - initialized_server.client_manager(), - lock, - ), + thread, f_stop = start_update_client_manager_thread( + driver, initialized_server.client_manager() ) - thread.start() # Start training hist = run_fl( @@ -145,66 +143,10 @@ def start_driver( # pylint: disable=too-many-arguments, 
too-many-locals config=initialized_config, ) - # Stop the Driver API server and the thread - with lock: - driver.disconnect() + # Terminate the thread + f_stop.set() thread.join() event(EventType.START_SERVER_LEAVE) return hist - - -def update_client_manager( - driver: GrpcDriver, - client_manager: ClientManager, - lock: threading.Lock, -) -> None: - """Update the nodes list in the client manager. - - This function periodically communicates with the associated driver to get all - node_ids. Each node_id is then converted into a `DriverClientProxy` instance - and stored in the `registered_nodes` dictionary with node_id as key. - - New nodes will be added to the ClientManager via `client_manager.register()`, - and dead nodes will be removed from the ClientManager via - `client_manager.unregister()`. - """ - # Request for workload_id - workload_id = driver.create_workload(driver_pb2.CreateWorkloadRequest()).workload_id - - # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} - while True: - with lock: - # End the while loop if the driver is disconnected - if driver.stub is None: - break - get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(workload_id=workload_id) - ) - all_node_ids = {node.node_id for node in get_nodes_res.nodes} - dead_nodes = set(registered_nodes).difference(all_node_ids) - new_nodes = all_node_ids.difference(registered_nodes) - - # Unregister dead nodes - for node_id in dead_nodes: - client_proxy = registered_nodes[node_id] - client_manager.unregister(client_proxy) - del registered_nodes[node_id] - - # Register new nodes - for node_id in new_nodes: - client_proxy = DriverClientProxy( - node_id=node_id, - driver=driver, - anonymous=False, - workload_id=workload_id, - ) - if client_manager.register(client_proxy): - registered_nodes[node_id] = client_proxy - else: - raise RuntimeError("Could not register node.") - - # Sleep for 3 seconds - time.sleep(3) diff --git 
a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py new file mode 100644 index 000000000000..696ec1132c4a --- /dev/null +++ b/src/py/flwr/server/compat/app_utils.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for the `start_driver`.""" + + +import threading +import time +from typing import Dict, Tuple + +from ..client_manager import ClientManager +from ..compat.driver_client_proxy import DriverClientProxy +from ..driver import Driver + + +def start_update_client_manager_thread( + driver: Driver, + client_manager: ClientManager, +) -> Tuple[threading.Thread, threading.Event]: + """Periodically update the nodes list in the client manager in a thread. + + This function starts a thread that periodically uses the associated driver to + get all node_ids. Each node_id is then converted into a `DriverClientProxy` + instance and stored in the `registered_nodes` dictionary with node_id as key. + + New nodes will be added to the ClientManager via `client_manager.register()`, + and dead nodes will be removed from the ClientManager via + `client_manager.unregister()`. + + Parameters + ---------- + driver : Driver + The Driver object to use. + client_manager : ClientManager + The ClientManager object to be updated. 
+ + Returns + ------- + threading.Thread + A thread that updates the ClientManager and handles the stop event. + threading.Event + An event that, when set, signals the thread to stop. + """ + f_stop = threading.Event() + thread = threading.Thread( + target=_update_client_manager, + args=( + driver, + client_manager, + f_stop, + ), + ) + thread.start() + + return thread, f_stop + + +def _update_client_manager( + driver: Driver, + client_manager: ClientManager, + f_stop: threading.Event, +) -> None: + """Update the nodes list in the client manager.""" + # Loop until the driver is disconnected + registered_nodes: Dict[int, DriverClientProxy] = {} + while not f_stop.is_set(): + all_node_ids = set(driver.get_node_ids()) + dead_nodes = set(registered_nodes).difference(all_node_ids) + new_nodes = all_node_ids.difference(registered_nodes) + + # Unregister dead nodes + for node_id in dead_nodes: + client_proxy = registered_nodes[node_id] + client_manager.unregister(client_proxy) + del registered_nodes[node_id] + + # Register new nodes + for node_id in new_nodes: + client_proxy = DriverClientProxy( + node_id=node_id, + driver=driver.grpc_driver, # type: ignore + anonymous=False, + run_id=driver.run_id, # type: ignore + ) + if client_manager.register(client_proxy): + registered_nodes[node_id] = client_proxy + else: + raise RuntimeError("Could not register node.") + + # Sleep for 3 seconds + time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py new file mode 100644 index 000000000000..7e47e6eaaf32 --- /dev/null +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utility functions for the `start_driver`.""" + + +import time +import unittest +from unittest.mock import Mock, patch + +from ..client_manager import SimpleClientManager +from .app_utils import start_update_client_manager_thread + + +class TestUtils(unittest.TestCase): + """Tests for utility functions.""" + + def test_start_update_client_manager_thread(self) -> None: + """Test start_update_client_manager_thread function.""" + # Prepare + sleep = time.sleep + sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) + sleep_patch.start() + expected_node_ids = list(range(100)) + updated_expected_node_ids = list(range(80, 120)) + driver = Mock() + driver.grpc_driver = Mock() + driver.run_id = 123 + driver.get_node_ids.return_value = expected_node_ids + client_manager = SimpleClientManager() + + # Execute + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Update the GetNodesResponse and wait until the `client_manager` is updated + driver.get_node_ids.return_value = updated_expected_node_ids + sleep(0.1) + # Retrieve all nodes in `client_manager` + updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Stop the thread + f_stop.set() + + # Assert + assert node_ids == 
set(expected_node_ids) + assert updated_node_ids == set(updated_expected_node_ids) + + # Exit + thread.join() diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py new file mode 100644 index 000000000000..84c67149fad7 --- /dev/null +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -0,0 +1,169 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ClientProxy implementation for Driver API.""" + + +import time +from typing import List, Optional + +from flwr import common +from flwr.common import MessageType, MessageTypeLegacy, RecordSet +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 +from flwr.server.client_proxy import ClientProxy + +from ..driver.grpc_driver import GrpcDriver + +SLEEP_TIME = 1 + + +class DriverClientProxy(ClientProxy): + """Flower client proxy which delegates work using the Driver API.""" + + def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int): + super().__init__(str(node_id)) + self.node_id = node_id + self.driver = driver + self.run_id = run_id + self.anonymous = anonymous + + def get_properties( + self, + ins: common.GetPropertiesIns, + timeout: Optional[float], + group_id: Optional[int], + ) -> 
common.GetPropertiesRes: + """Return client's properties.""" + # Ins to RecordSet + out_recordset = compat.getpropertiesins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id + ) + # RecordSet to Res + return compat.recordset_to_getpropertiesres(in_recordset) + + def get_parameters( + self, + ins: common.GetParametersIns, + timeout: Optional[float], + group_id: Optional[int], + ) -> common.GetParametersRes: + """Return the current local model parameters.""" + # Ins to RecordSet + out_recordset = compat.getparametersins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id + ) + # RecordSet to Res + return compat.recordset_to_getparametersres(in_recordset, False) + + def fit( + self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.FitRes: + """Train model parameters on the locally held dataset.""" + # Ins to RecordSet + out_recordset = compat.fitins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageType.TRAIN, timeout, group_id + ) + # RecordSet to Res + return compat.recordset_to_fitres(in_recordset, keep_input=False) + + def evaluate( + self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.EvaluateRes: + """Evaluate model parameters on the locally held dataset.""" + # Ins to RecordSet + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, MessageType.EVALUATE, timeout, group_id + ) + # RecordSet to Res + return compat.recordset_to_evaluateres(in_recordset) + + def reconnect( + self, + ins: common.ReconnectIns, + timeout: Optional[float], + group_id: Optional[int], + ) -> common.DisconnectRes: + """Disconnect and 
(optionally) reconnect later.""" + return common.DisconnectRes(reason="") # Nothing to do here (yet) + + def _send_receive_recordset( + self, + recordset: RecordSet, + task_type: str, + timeout: Optional[float], + group_id: Optional[int], + ) -> RecordSet: + task_ins = task_pb2.TaskIns( # pylint: disable=E1101 + task_id="", + group_id=str(group_id) if group_id is not None else "", + run_id=self.run_id, + task=task_pb2.Task( # pylint: disable=E1101 + producer=node_pb2.Node( # pylint: disable=E1101 + node_id=0, + anonymous=True, + ), + consumer=node_pb2.Node( # pylint: disable=E1101 + node_id=self.node_id, + anonymous=self.anonymous, + ), + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), + ), + ) + push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 + task_ins_list=[task_ins] + ) + + # Send TaskIns to Driver API + push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req) + + if len(push_task_ins_res.task_ids) != 1: + raise ValueError("Unexpected number of task_ids") + + task_id = push_task_ins_res.task_ids[0] + if task_id == "": + raise ValueError(f"Failed to schedule task for node {self.node_id}") + + if timeout: + start_time = time.time() + + while True: + pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101 + node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101 + task_ids=[task_id], + ) + + # Ask Driver API for TaskRes + pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req) + + task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101 + pull_task_res_res.task_res_list + ) + if len(task_res_list) == 1: + task_res = task_res_list[0] + return serde.recordset_from_proto(task_res.task.recordset) + + if timeout is not None and time.time() > start_time + timeout: + raise RuntimeError("Timeout reached") + time.sleep(SLEEP_TIME) diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py 
new file mode 100644 index 000000000000..3494049c1064 --- /dev/null +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -0,0 +1,245 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""DriverClientProxy tests.""" + + +import unittest +from typing import Union, cast +from unittest.mock import MagicMock + +import numpy as np + +import flwr +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.typing import ( + Code, + Config, + EvaluateIns, + EvaluateRes, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesRes, + Parameters, + Properties, + Status, +) +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 + +from .driver_client_proxy import DriverClientProxy + +MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") + +CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) +CLIENT_STATUS = Status(code=Code.OK, message="OK") + + +def _make_task( + res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] +) -> task_pb2.Task: # pylint: disable=E1101 + if isinstance(res, GetParametersRes): + message_type = MessageTypeLegacy.GET_PARAMETERS + recordset = compat.getparametersres_to_recordset(res, True) + elif isinstance(res, GetPropertiesRes): 
+ message_type = MessageTypeLegacy.GET_PROPERTIES + recordset = compat.getpropertiesres_to_recordset(res) + elif isinstance(res, FitRes): + message_type = MessageType.TRAIN + recordset = compat.fitres_to_recordset(res, True) + elif isinstance(res, EvaluateRes): + message_type = MessageType.EVALUATE + recordset = compat.evaluateres_to_recordset(res) + else: + raise ValueError(f"Unsupported type: {type(res)}") + return task_pb2.Task( # pylint: disable=E1101 + task_type=message_type, + recordset=serde.recordset_to_proto(recordset), + ) + + +class DriverClientProxyTestCase(unittest.TestCase): + """Tests for DriverClientProxy.""" + + def setUp(self) -> None: + """Set up mocks for tests.""" + self.driver = MagicMock() + self.driver.get_nodes.return_value = ( + driver_pb2.GetNodesResponse( # pylint: disable=E1101 + nodes=[ + node_pb2.Node(node_id=1, anonymous=False) # pylint: disable=E1101 + ] + ) + ) + + def test_get_properties(self) -> None: + """Test positive case.""" + # Prepare + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) + ) + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id=str(0), + run_id=0, + task=_make_task( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES + ) + ), + ) + ] + ) + ) + client = DriverClientProxy( + node_id=1, driver=self.driver, anonymous=True, run_id=0 + ) + request_properties: Config = {"tensor_type": "str"} + ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( + config=request_properties + ) + + # Execute + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) + + # Assert + assert value.properties["tensor_type"] == "numpy.ndarray" + + def test_get_parameters(self) -> None: + 
"""Test positive case.""" + # Prepare + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) + ) + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id=str(0), + run_id=0, + task=_make_task( + GetParametersRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + ) + ), + ) + ] + ) + ) + client = DriverClientProxy( + node_id=1, driver=self.driver, anonymous=True, run_id=0 + ) + get_parameters_ins = GetParametersIns(config={}) + + # Execute + value: flwr.common.GetParametersRes = client.get_parameters( + ins=get_parameters_ins, timeout=None, group_id=0 + ) + + # Assert + assert value.parameters.tensors[0] == b"abc" + + def test_fit(self) -> None: + """Test positive case.""" + # Prepare + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) + ) + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id=str(1), + run_id=0, + task=_make_task( + FitRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + num_examples=10, + metrics={}, + ) + ), + ) + ] + ) + ) + client = DriverClientProxy( + node_id=1, driver=self.driver, anonymous=True, run_id=0 + ) + parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) + ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) + + # Execute + fit_res = client.fit(ins=ins, timeout=None, group_id=1) + + # Assert + assert fit_res.parameters.tensor_type == "np" + assert fit_res.parameters.tensors[0] == b"abc" + assert fit_res.num_examples == 10 + + def test_evaluate(self) 
-> None: + """Test positive case.""" + # Prepare + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) + ) + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id=str(1), + run_id=0, + task=_make_task( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, + ) + ), + ) + ] + ) + ) + client = DriverClientProxy( + node_id=1, driver=self.driver, anonymous=True, run_id=0 + ) + parameters = Parameters(tensors=[], tensor_type="np") + evaluate_ins = EvaluateIns(parameters, {}) + + # Execute + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) + + # Assert + assert 0.0 == evaluate_res.loss + assert 0 == evaluate_res.num_examples diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py new file mode 100644 index 000000000000..0b00c98bb16d --- /dev/null +++ b/src/py/flwr/server/compat/legacy_context.py @@ -0,0 +1,55 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Legacy Context.""" + + +from dataclasses import dataclass +from typing import Optional + +from flwr.common import Context, RecordSet + +from ..client_manager import ClientManager, SimpleClientManager +from ..history import History +from ..server_config import ServerConfig +from ..strategy import FedAvg, Strategy + + +@dataclass +class LegacyContext(Context): + """Legacy Context.""" + + config: ServerConfig + strategy: Strategy + client_manager: ClientManager + history: History + + def __init__( + self, + state: RecordSet, + config: Optional[ServerConfig] = None, + strategy: Optional[Strategy] = None, + client_manager: Optional[ClientManager] = None, + ) -> None: + if config is None: + config = ServerConfig() + if strategy is None: + strategy = FedAvg() + if client_manager is None: + client_manager = SimpleClientManager() + self.config = config + self.strategy = strategy + self.client_manager = client_manager + self.history = History() + super().__init__(state) diff --git a/src/py/flwr/server/criterion_test.py b/src/py/flwr/server/criterion_test.py index a7e5b62b5977..f678825f064e 100644 --- a/src/py/flwr/server/criterion_test.py +++ b/src/py/flwr/server/criterion_test.py @@ -20,7 +20,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.criterion import Criterion -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def test_criterion_applied() -> None: diff --git a/src/py/flwr/server/driver/__init__.py b/src/py/flwr/server/driver/__init__.py index 2bfe63e6065f..b61f6eebf6a8 100644 --- a/src/py/flwr/server/driver/__init__.py +++ b/src/py/flwr/server/driver/__init__.py @@ -12,4 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Flower driver service.""" +"""Flower driver SDK.""" + + +from .driver import Driver +from .grpc_driver import GrpcDriver + +__all__ = [ + "Driver", + "GrpcDriver", +] diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py new file mode 100644 index 000000000000..0098e0ce97c2 --- /dev/null +++ b/src/py/flwr/server/driver/driver.py @@ -0,0 +1,256 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower driver service client.""" + + +import time +from typing import Iterable, List, Optional, Tuple + +from flwr.common import Message, Metadata, RecordSet +from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + GetNodesRequest, + PullTaskResRequest, + PushTaskInsRequest, +) +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + +from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver + + +class Driver: + """`Driver` class provides an interface to the Driver API. + + Parameters + ---------- + driver_service_address : Optional[str] + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. 
+ certificates : bytes (default: None) + Tuple containing root certificate, server certificate, and private key + to start a secure SSL-enabled server. The tuple is expected to have + three bytes elements in the following order: + + * CA certificate. + * server certificate. + * server private key. + """ + + def __init__( + self, + driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, + root_certificates: Optional[bytes] = None, + ) -> None: + self.addr = driver_service_address + self.root_certificates = root_certificates + self.grpc_driver: Optional[GrpcDriver] = None + self.run_id: Optional[int] = None + self.node = Node(node_id=0, anonymous=True) + + def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: + # Check if the GrpcDriver is initialized + if self.grpc_driver is None or self.run_id is None: + # Connect and create run + self.grpc_driver = GrpcDriver( + driver_service_address=self.addr, + root_certificates=self.root_certificates, + ) + self.grpc_driver.connect() + res = self.grpc_driver.create_run(CreateRunRequest()) + self.run_id = res.run_id + return self.grpc_driver, self.run_id + + def _check_message(self, message: Message) -> None: + # Check if the message is valid + if not ( + message.metadata.run_id == self.run_id + and message.metadata.src_node_id == self.node.node_id + and message.metadata.message_id == "" + and message.metadata.reply_to_message == "" + ): + raise ValueError(f"Invalid message: {message}") + + def create_message( # pylint: disable=too-many-arguments + self, + content: RecordSet, + message_type: str, + dst_node_id: int, + group_id: str, + ttl: str, + ) -> Message: + """Create a new message with specified parameters. + + This method constructs a new `Message` with given content and metadata. + The `run_id` and `src_node_id` will be set automatically. + + Parameters + ---------- + content : RecordSet + The content for the new message. This holds records that are to be sent + to the destination node. 
+ message_type : str + The type of the message, defining the action to be executed on + the receiving end. + dst_node_id : int + The ID of the destination node to which the message is being sent. + group_id : str + The ID of the group to which this message is associated. In some settings, + this is used as the FL round. + ttl : str + Time-to-live for the round trip of this message, i.e., the time from sending + this message to receiving a reply. It specifies the duration for which the + message and its potential reply are considered valid. + + Returns + ------- + message : Message + A new `Message` instance with the specified content and metadata. + """ + _, run_id = self._get_grpc_driver_and_run_id() + metadata = Metadata( + run_id=run_id, + message_id="", # Will be set by the server + src_node_id=self.node.node_id, + dst_node_id=dst_node_id, + reply_to_message="", + group_id=group_id, + ttl=ttl, + message_type=message_type, + ) + return Message(metadata=metadata, content=content) + + def get_node_ids(self) -> List[int]: + """Get node IDs.""" + grpc_driver, run_id = self._get_grpc_driver_and_run_id() + # Call GrpcDriver method + res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) + return [node.node_id for node in res.nodes] + + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + """Push messages to specified node IDs. + + This method takes an iterable of messages and sends each message + to the node specified in `dst_node_id`. + + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + + Returns + ------- + message_ids : Iterable[str] + An iterable of IDs for the messages that were sent, which can be used + to pull replies. 
+ """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Construct TaskIns + task_ins_list: List[TaskIns] = [] + for msg in messages: + # Check message + self._check_message(msg) + # Convert Message to TaskIns + taskins = message_to_taskins(msg) + # Add to list + task_ins_list.append(taskins) + # Call GrpcDriver method + res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + return list(res.task_ids) + + def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: + """Pull messages based on message IDs. + + This method is used to collect messages from the SuperLink + that correspond to a set of given message IDs. + + Parameters + ---------- + message_ids : Iterable[str] + An iterable of message IDs for which reply messages are to be retrieved. + + Returns + ------- + messages : Iterable[Message] + An iterable of messages received. + """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Pull TaskRes + res = grpc_driver.pull_task_res( + PullTaskResRequest(node=self.node, task_ids=message_ids) + ) + # Convert TaskRes to Message + msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] + return msgs + + def send_and_receive( + self, + messages: Iterable[Message], + *, + timeout: Optional[float] = None, + ) -> Iterable[Message]: + """Push messages to specified node IDs and pull the reply messages. + + This method sends a list of messages to their destination node IDs and then + waits for the replies. It continues to pull replies until either all + replies are received or the specified timeout duration is exceeded. + + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the method will wait for + replies for this duration. If `None`, there is no time limit and the method + will wait until replies for all messages are received. 
+ + Returns + ------- + replies : Iterable[Message] + An iterable of reply messages received from the SuperLink. + + Notes + ----- + This method uses `push_messages` to send the messages and `pull_messages` + to collect the replies. If `timeout` is set, the method may not return + replies for all sent messages. A message remains valid until its TTL, + which is not affected by `timeout`. + """ + # Push messages + msg_ids = set(self.push_messages(messages)) + + # Pull messages + end_time = time.time() + (timeout if timeout is not None else 0.0) + ret: List[Message] = [] + while timeout is None or time.time() < end_time: + res_msgs = self.pull_messages(msg_ids) + ret.extend(res_msgs) + msg_ids.difference_update( + {msg.metadata.reply_to_message for msg in res_msgs} + ) + if len(msg_ids) == 0: + break + # Sleep + time.sleep(3) + return ret + + def close(self) -> None: + """Disconnect from the SuperLink if connected.""" + # Check if GrpcDriver is initialized + if self.grpc_driver is None: + return + # Disconnect + self.grpc_driver.disconnect() diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py new file mode 100644 index 000000000000..5136f4f90210 --- /dev/null +++ b/src/py/flwr/server/driver/driver_test.py @@ -0,0 +1,219 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for driver SDK.""" + + +import time +import unittest +from unittest.mock import Mock, patch + +from flwr.common import RecordSet +from flwr.common.message import Error +from flwr.common.serde import error_to_proto, recordset_to_proto +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + GetNodesRequest, + PullTaskResRequest, + PushTaskInsRequest, +) +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 + +from .driver import Driver + + +class TestDriver(unittest.TestCase): + """Tests for `Driver` class.""" + + def setUp(self) -> None: + """Initialize mock GrpcDriver and Driver instance before each test.""" + mock_response = Mock() + mock_response.run_id = 61016 + self.mock_grpc_driver = Mock() + self.mock_grpc_driver.create_run.return_value = mock_response + self.patcher = patch( + "flwr.server.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver + ) + self.patcher.start() + self.driver = Driver() + + def tearDown(self) -> None: + """Cleanup after each test.""" + self.patcher.stop() + + def test_check_and_init_grpc_driver_already_initialized(self) -> None: + """Test that GrpcDriver doesn't initialize if run is created.""" + # Prepare + self.driver.grpc_driver = self.mock_grpc_driver + self.driver.run_id = 61016 + + # Execute + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_run_id() + + # Assert + self.mock_grpc_driver.connect.assert_not_called() + + def test_check_and_init_grpc_driver_needs_initialization(self) -> None: + """Test GrpcDriver initialization when run is not created.""" + # Execute + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_run_id() + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(self.driver.run_id, 61016) + + def test_get_nodes(self) -> None: + """Test retrieval of nodes.""" + # Prepare + mock_response = Mock() + 
mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] + self.mock_grpc_driver.get_nodes.return_value = mock_response + + # Execute + node_ids = self.driver.get_node_ids() + args, kwargs = self.mock_grpc_driver.get_nodes.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], GetNodesRequest) + self.assertEqual(args[0].run_id, 61016) + self.assertEqual(node_ids, [404, 200]) + + def test_push_messages_valid(self) -> None: + """Test pushing valid messages.""" + # Prepare + mock_response = Mock(task_ids=["id1", "id2"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] + + # Execute + msg_ids = self.driver.push_messages(msgs) + args, kwargs = self.mock_grpc_driver.push_task_ins.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], PushTaskInsRequest) + self.assertEqual(msg_ids, mock_response.task_ids) + for task_ins in args[0].task_ins_list: + self.assertEqual(task_ins.run_id, 61016) + + def test_push_messages_invalid(self) -> None: + """Test pushing invalid messages.""" + # Prepare + mock_response = Mock(task_ids=["id1", "id2"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] + # Use invalid run_id + msgs[1].metadata._run_id += 1 # pylint: disable=protected-access + + # Execute and assert + with self.assertRaises(ValueError): + self.driver.push_messages(msgs) + + def test_pull_messages_with_given_message_ids(self) -> None: + """Test pulling messages with specific message IDs.""" + # Prepare + mock_response = Mock() + # A Message must have either content or error set so we prepare + # two tasks that contain 
these. + mock_response.task_res_list = [ + TaskRes( + task=Task(ancestry=["id2"], recordset=recordset_to_proto(RecordSet())) + ), + TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), + ] + self.mock_grpc_driver.pull_task_res.return_value = mock_response + msg_ids = ["id1", "id2", "id3"] + + # Execute + msgs = self.driver.pull_messages(msg_ids) + reply_tos = {msg.metadata.reply_to_message for msg in msgs} + args, kwargs = self.mock_grpc_driver.pull_task_res.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], PullTaskResRequest) + self.assertEqual(args[0].task_ids, msg_ids) + self.assertEqual(reply_tos, {"id2", "id3"}) + + def test_send_and_receive_messages_complete(self) -> None: + """Test send and receive all messages successfully.""" + # Prepare + mock_response = Mock(task_ids=["id1"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + # The response message must include either `content` (i.e. a recordset) or + # an `Error`. 
We choose the latter in this case + error_proto = error_to_proto(Error(code=0)) + mock_response = Mock( + task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] + ) + self.mock_grpc_driver.pull_task_res.return_value = mock_response + msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + + # Execute + ret_msgs = list(self.driver.send_and_receive(msgs)) + + # Assert + self.assertEqual(len(ret_msgs), 1) + self.assertEqual(ret_msgs[0].metadata.reply_to_message, "id1") + + def test_send_and_receive_messages_timeout(self) -> None: + """Test send and receive messages but time out.""" + # Prepare + sleep_fn = time.sleep + mock_response = Mock(task_ids=["id1"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + mock_response = Mock(task_res_list=[]) + self.mock_grpc_driver.pull_task_res.return_value = mock_response + msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + + # Execute + with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): + start_time = time.time() + ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15)) + + # Assert + self.assertLess(time.time() - start_time, 0.2) + self.assertEqual(len(ret_msgs), 0) + + def test_del_with_initialized_driver(self) -> None: + """Test cleanup behavior when Driver is initialized.""" + # Prepare + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_run_id() + + # Execute + self.driver.close() + + # Assert + self.mock_grpc_driver.disconnect.assert_called_once() + + def test_del_with_uninitialized_driver(self) -> None: + """Test cleanup behavior when Driver is not initialized.""" + # Execute + self.driver.close() + + # Assert + self.mock_grpc_driver.disconnect.assert_not_called() diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py similarity index 76% rename from src/py/flwr/driver/grpc_driver.py rename to src/py/flwr/server/driver/grpc_driver.py index 7dd0a0f501c5..b6e2b2602cd5 100644 --- 
a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -15,7 +15,7 @@ """Flower driver service client.""" -from logging import ERROR, INFO, WARNING +from logging import DEBUG, ERROR, WARNING from typing import Optional import grpc @@ -23,9 +23,9 @@ from flwr.common import EventType, event from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - CreateWorkloadResponse, +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -33,7 +33,7 @@ PushTaskInsRequest, PushTaskInsResponse, ) -from flwr.proto.driver_pb2_grpc import DriverStub +from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -51,10 +51,10 @@ class GrpcDriver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.driver_service_address = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None @@ -66,33 +66,33 @@ def connect(self) -> None: return self.channel = create_channel( server_address=self.driver_service_address, - insecure=(self.certificates is None), - root_certificates=self.certificates, + insecure=(self.root_certificates is None), + root_certificates=self.root_certificates, ) self.stub = DriverStub(self.channel) - log(INFO, "[Driver] Connected to %s", self.driver_service_address) + log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) def disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) if self.channel is None or self.stub is None: - log(WARNING, "Already disconnected") + log(DEBUG, "Already 
disconnected") return channel = self.channel self.channel = None self.stub = None channel.close() - log(INFO, "[Driver] Disconnected") + log(DEBUG, "[Driver] Disconnected") - def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: - """Request for workload ID.""" + def create_run(self, req: CreateRunRequest) -> CreateRunResponse: + """Request for run ID.""" # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API - res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) + res: CreateRunResponse = self.stub.CreateRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: @@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git 
a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py b/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py deleted file mode 100644 index dc94bf3912d7..000000000000 --- a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower DriverClientManager.""" - - -import threading -from typing import Dict, List, Optional, Set, Tuple - -from flwr.server.client_manager import ClientManager -from flwr.server.client_proxy import ClientProxy -from flwr.server.criterion import Criterion -from flwr.server.state import State, StateFactory - -from .ins_scheduler import InsScheduler - - -class DriverClientManager(ClientManager): - """Provides a pool of available clients.""" - - def __init__(self, state_factory: StateFactory) -> None: - self._cv = threading.Condition() - self.nodes: Dict[str, Tuple[int, InsScheduler]] = {} - self.state_factory = state_factory - - def __len__(self) -> int: - """Return the number of available clients. - - Returns - ------- - num_available : int - The number of currently available clients. - """ - return len(self.nodes) - - def num_available(self) -> int: - """Return the number of available clients. - - Returns - ------- - num_available : int - The number of currently available clients. 
- """ - return len(self) - - def register(self, client: ClientProxy) -> bool: - """Register Flower ClientProxy instance. - - Parameters - ---------- - client : flwr.server.client_proxy.ClientProxy - - Returns - ------- - success : bool - Indicating if registration was successful. False if ClientProxy is - already registered or can not be registered for any reason. - """ - if client.cid in self.nodes: - return False - - # Create node in State - state: State = self.state_factory.state() - client.node_id = state.create_node() - - # Create and start the instruction scheduler - ins_scheduler = InsScheduler( - client_proxy=client, - state_factory=self.state_factory, - ) - ins_scheduler.start() - - # Store cid, node_id, and InsScheduler - self.nodes[client.cid] = (client.node_id, ins_scheduler) - - with self._cv: - self._cv.notify_all() - - return True - - def unregister(self, client: ClientProxy) -> None: - """Unregister Flower ClientProxy instance. - - This method is idempotent. - - Parameters - ---------- - client : flwr.server.client_proxy.ClientProxy - """ - if client.cid in self.nodes: - node_id, ins_scheduler = self.nodes[client.cid] - del self.nodes[client.cid] - ins_scheduler.stop() - - # Delete node_id in State - state: State = self.state_factory.state() - state.delete_node(node_id=node_id) - - with self._cv: - self._cv.notify_all() - - def all_ids(self) -> Set[int]: - """Return all available node ids. - - Returns - ------- - ids : Set[int] - The IDs of all currently available nodes. 
- """ - return {node_id for _, (node_id, _) in self.nodes.items()} - - # --- Unimplemented methods ----------------------------------------------- - - def all(self) -> Dict[str, ClientProxy]: - """Not implemented.""" - raise NotImplementedError() - - def wait_for(self, num_clients: int, timeout: int = 86400) -> bool: - """Not implemented.""" - raise NotImplementedError() - - def sample( - self, - num_clients: int, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: - """Not implemented.""" - raise NotImplementedError() diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py deleted file mode 100644 index 1c737d31c7fc..000000000000 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Instruction scheduler for the legacy gRPC transport stack.""" - - -import threading -import time -from logging import DEBUG, ERROR -from typing import Dict, List, Optional - -from flwr.client.message_handler.task_handler import configure_task_res -from flwr.common import EvaluateRes, FitRes, GetParametersRes, GetPropertiesRes, serde -from flwr.common.logger import log -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.server.client_proxy import ClientProxy -from flwr.server.state import State, StateFactory - - -class InsScheduler: - """Schedule ClientProxy calls on a background thread.""" - - def __init__(self, client_proxy: ClientProxy, state_factory: StateFactory): - self.client_proxy = client_proxy - self.state_factory = state_factory - self.worker_thread: Optional[threading.Thread] = None - self.shared_memory_state = {"stop": False} - - def start(self) -> None: - """Start the worker thread.""" - self.worker_thread = threading.Thread( - target=_worker, - args=( - self.client_proxy, - self.shared_memory_state, - self.state_factory, - ), - ) - self.worker_thread.start() - - def stop(self) -> None: - """Stop the worker thread.""" - if self.worker_thread is None: - log(ERROR, "InsScheduler.stop called, but worker_thread is None") - return - self.shared_memory_state["stop"] = True - self.worker_thread.join() - self.worker_thread = None - self.shared_memory_state["stop"] = False - - -def _worker( - client_proxy: ClientProxy, - shared_memory_state: Dict[str, bool], - state_factory: StateFactory, -) -> None: - """Sequentially call ClientProxy methods to process outstanding tasks.""" - log(DEBUG, "Worker for node %i started", client_proxy.node_id) - - state: State = state_factory.state() - while not shared_memory_state["stop"]: - log(DEBUG, "Worker for node %i 
checking state", client_proxy.node_id) - - # Step 1: pull *Ins (next task) out of `state` - task_ins_list: List[TaskIns] = state.get_task_ins( - node_id=client_proxy.node_id, - limit=1, - ) - if not task_ins_list: - log(DEBUG, "Worker for node %i: no task found", client_proxy.node_id) - time.sleep(3) - continue - - task_ins = task_ins_list[0] - log( - DEBUG, - "Worker for node %i: FOUND task %s", - client_proxy.node_id, - task_ins.task_id, - ) - - # Step 2: call client_proxy.{fit,evaluate,...} - server_message = task_ins.task.legacy_server_message - client_message_proto = _call_client_proxy( - client_proxy=client_proxy, - server_message=server_message, - timeout=None, - ) - - # Step 3: wrap FitRes in a ClientMessage in a Task in a TaskRes - task_res = configure_task_res( - TaskRes(task=Task(legacy_client_message=client_message_proto)), - task_ins, - Node(node_id=client_proxy.node_id, anonymous=False), - ) - - # Step 4: write *Res (result) back to `state` - state.store_task_res(task_res=task_res) - - # Exit worker thread - log(DEBUG, "Worker for node %i stopped", client_proxy.node_id) - - -def _call_client_proxy( - client_proxy: ClientProxy, server_message: ServerMessage, timeout: Optional[float] -) -> ClientMessage: - """.""" - # pylint: disable=too-many-locals - - field = server_message.WhichOneof("msg") - - if field == "get_properties_ins": - get_properties_ins = serde.get_properties_ins_from_proto( - msg=server_message.get_properties_ins - ) - get_properties_res: GetPropertiesRes = client_proxy.get_properties( - ins=get_properties_ins, - timeout=timeout, - ) - get_properties_res_proto = serde.get_properties_res_to_proto( - res=get_properties_res - ) - return ClientMessage(get_properties_res=get_properties_res_proto) - - if field == "get_parameters_ins": - get_parameters_ins = serde.get_parameters_ins_from_proto( - msg=server_message.get_parameters_ins - ) - get_parameters_res: GetParametersRes = client_proxy.get_parameters( - ins=get_parameters_ins, - 
timeout=timeout, - ) - get_parameters_res_proto = serde.get_parameters_res_to_proto( - res=get_parameters_res - ) - return ClientMessage(get_parameters_res=get_parameters_res_proto) - - if field == "fit_ins": - fit_ins = serde.fit_ins_from_proto(msg=server_message.fit_ins) - fit_res: FitRes = client_proxy.fit( - ins=fit_ins, - timeout=timeout, - ) - fit_res_proto = serde.fit_res_to_proto(res=fit_res) - return ClientMessage(fit_res=fit_res_proto) - - if field == "evaluate_ins": - evaluate_ins = serde.evaluate_ins_from_proto(msg=server_message.evaluate_ins) - evaluate_res: EvaluateRes = client_proxy.evaluate( - ins=evaluate_ins, - timeout=timeout, - ) - evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res) - return ClientMessage(evaluate_res=evaluate_res_proto) - - raise Exception( - "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" - ) diff --git a/src/py/flwr/server/history.py b/src/py/flwr/server/history.py index ad5f5d0fc870..c4298911d97b 100644 --- a/src/py/flwr/server/history.py +++ b/src/py/flwr/server/history.py @@ -15,6 +15,7 @@ """Training history.""" +import pprint from functools import reduce from typing import Dict, List, Tuple @@ -90,29 +91,35 @@ def __repr__(self) -> str: """ rep = "" if self.losses_distributed: - rep += "History (loss, distributed):\n" + reduce( - lambda a, b: a + b, - [ - f"\tround {server_round}: {loss}\n" - for server_round, loss in self.losses_distributed - ], + rep += "History (loss, distributed):\n" + pprint.pformat( + reduce( + lambda a, b: a + b, + [ + f"\tround {server_round}: {loss}\n" + for server_round, loss in self.losses_distributed + ], + ) ) if self.losses_centralized: - rep += "History (loss, centralized):\n" + reduce( - lambda a, b: a + b, - [ - f"\tround {server_round}: {loss}\n" - for server_round, loss in self.losses_centralized - ], + rep += "History (loss, centralized):\n" + pprint.pformat( + reduce( + lambda a, b: a + b, + [ + f"\tround {server_round}: {loss}\n" + for 
server_round, loss in self.losses_centralized + ], + ) ) if self.metrics_distributed_fit: - rep += "History (metrics, distributed, fit):\n" + str( + rep += "History (metrics, distributed, fit):\n" + pprint.pformat( self.metrics_distributed_fit ) if self.metrics_distributed: - rep += "History (metrics, distributed, evaluate):\n" + str( + rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat( self.metrics_distributed ) if self.metrics_centralized: - rep += "History (metrics, centralized):\n" + str(self.metrics_centralized) + rep += "History (metrics, centralized):\n" + pprint.pformat( + self.metrics_centralized + ) return rep diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py new file mode 100644 index 000000000000..2f0f1185847e --- /dev/null +++ b/src/py/flwr/server/run_serverapp.py @@ -0,0 +1,187 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Run ServerApp.""" + + +import argparse +import sys +from logging import DEBUG, INFO, WARN +from pathlib import Path +from typing import Optional + +from flwr.common import Context, EventType, RecordSet, event +from flwr.common.logger import log, update_console_handler +from flwr.common.object_ref import load_app + +from .driver.driver import Driver +from .server_app import LoadServerAppError, ServerApp + + +def run( + driver: Driver, + server_app_dir: str, + server_app_attr: Optional[str] = None, + loaded_server_app: Optional[ServerApp] = None, +) -> None: + """Run ServerApp with a given Driver.""" + if not (server_app_attr is None) ^ (loaded_server_app is None): + raise ValueError( + "Either `server_app_attr` or `loaded_server_app` should be set " + "but not both. " + ) + + if server_app_dir is not None: + sys.path.insert(0, server_app_dir) + + # Load ServerApp if needed + def _load() -> ServerApp: + if server_app_attr: + server_app: ServerApp = load_app(server_app_attr, LoadServerAppError) + + if not isinstance(server_app, ServerApp): + raise LoadServerAppError( + f"Attribute {server_app_attr} is not of type {ServerApp}", + ) from None + + if loaded_server_app: + server_app = loaded_server_app + return server_app + + server_app = _load() + + # Initialize Context + context = Context(state=RecordSet()) + + # Call ServerApp + server_app(driver=driver, context=context) + + log(DEBUG, "ServerApp finished running.") + + +def run_server_app() -> None: + """Run Flower server app.""" + event(EventType.RUN_SERVER_APP_ENTER) + + args = _parse_args_run_server_app().parse_args() + + update_console_handler( + level=DEBUG if args.verbose else INFO, + timestamps=args.verbose, + colored=True, + ) + + # Obtain certificates + if args.insecure: + if args.root_certificates is not None: + sys.exit( + "Conflicting options: The '--insecure' flag disables HTTPS, " + "but '--root-certificates' was also 
specified. Please remove " + "the '--root-certificates' option when running in insecure mode, " + "or omit '--insecure' to use HTTPS." + ) + log( + WARN, + "Option `--insecure` was set. " + "Starting insecure HTTP client connected to %s.", + args.server, + ) + root_certificates = None + else: + # Load the certificates if provided, or load the system certificates + cert_path = args.root_certificates + if cert_path is None: + root_certificates = None + else: + root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) + + log( + DEBUG, + "Flower will load ServerApp `%s`", + getattr(args, "server-app"), + ) + + log( + DEBUG, + "root_certificates: `%s`", + root_certificates, + ) + + server_app_dir = args.dir + server_app_attr = getattr(args, "server-app") + + # Initialize Driver + driver = Driver( + driver_service_address=args.server, + root_certificates=root_certificates, + ) + + # Run the Server App with the Driver + run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) + + # Clean up + driver.close() + + event(EventType.RUN_SERVER_APP_LEAVE) + + +def _parse_args_run_server_app() -> argparse.ArgumentParser: + """Parse flower-server-app command line arguments.""" + parser = argparse.ArgumentParser( + description="Start a Flower server app", + ) + + parser.add_argument( + "server-app", + help="For example: `server:app` or `project.package.module:wrapper.app`", + ) + parser.add_argument( + "--insecure", + action="store_true", + help="Run the server app without HTTPS. By default, the app runs with " + "HTTPS enabled. 
Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Set the logging to `DEBUG`.", + ) + parser.add_argument( + "--root-certificates", + metavar="ROOT_CERT", + type=str, + help="Specifies the path to the PEM-encoded root certificate file for " + "establishing secure HTTPS connections.", + ) + parser.add_argument( + "--server", + default="0.0.0.0:9091", + help="Server address", + ) + parser.add_argument( + "--dir", + default="", + help="Add specified directory to the PYTHONPATH and load Flower " + "app from there." + " Default: current working directory.", + ) + + return parser diff --git a/src/py/flwr/server/server.py b/src/py/flwr/server/server.py index cf3a4d9aa07c..981325a6df08 100644 --- a/src/py/flwr/server/server.py +++ b/src/py/flwr/server/server.py @@ -16,8 +16,9 @@ import concurrent.futures +import io import timeit -from logging import DEBUG, INFO +from logging import INFO, WARN from typing import Dict, List, Optional, Tuple, Union from flwr.common import ( @@ -33,11 +34,13 @@ ) from flwr.common.logger import log from flwr.common.typing import GetParametersIns -from flwr.server.client_manager import ClientManager +from flwr.server.client_manager import ClientManager, SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.history import History from flwr.server.strategy import FedAvg, Strategy +from .server_config import ServerConfig + FitResultsAndFailures = Tuple[ List[Tuple[ClientProxy, FitRes]], List[Union[Tuple[ClientProxy, FitRes], BaseException]], @@ -81,14 +84,14 @@ def client_manager(self) -> ClientManager: return self._client_manager # pylint: disable=too-many-locals - def fit(self, num_rounds: int, timeout: Optional[float]) -> History: + def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: """Run federated averaging for a number of rounds.""" history = History() # Initialize parameters - log(INFO, "Initializing global 
parameters") - self.parameters = self._get_initial_parameters(timeout=timeout) - log(INFO, "Evaluating initial parameters") + log(INFO, "[INIT]") + self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout) + log(INFO, "Evaluating initial global parameters") res = self.strategy.evaluate(0, parameters=self.parameters) if res is not None: log( @@ -101,10 +104,11 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: history.add_metrics_centralized(server_round=0, metrics=res[1]) # Run federated learning for num_rounds - log(INFO, "FL starting") start_time = timeit.default_timer() for current_round in range(1, num_rounds + 1): + log(INFO, "") + log(INFO, "[ROUND %s]", current_round) # Train model and replace previous global model res_fit = self.fit_round( server_round=current_round, @@ -150,8 +154,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: # Bookkeeping end_time = timeit.default_timer() elapsed = end_time - start_time - log(INFO, "FL finished in %s", elapsed) - return history + return history, elapsed def evaluate_round( self, @@ -168,12 +171,11 @@ def evaluate_round( client_manager=self._client_manager, ) if not client_instructions: - log(INFO, "evaluate_round %s: no clients selected, cancel", server_round) + log(INFO, "configure_evaluate: no clients selected, skipping evaluation") return None log( - DEBUG, - "evaluate_round %s: strategy sampled %s clients (out of %s)", - server_round, + INFO, + "configure_evaluate: strategy sampled %s clients (out of %s)", len(client_instructions), self._client_manager.num_available(), ) @@ -183,11 +185,11 @@ def evaluate_round( client_instructions, max_workers=self.max_workers, timeout=timeout, + group_id=server_round, ) log( - DEBUG, - "evaluate_round %s received %s results and %s failures", - server_round, + INFO, + "aggregate_evaluate: received %s results and %s failures", len(results), len(failures), ) @@ -217,12 +219,11 @@ def fit_round( ) if not 
client_instructions: - log(INFO, "fit_round %s: no clients selected, cancel", server_round) + log(INFO, "configure_fit: no clients selected, cancel") return None log( - DEBUG, - "fit_round %s: strategy sampled %s clients (out of %s)", - server_round, + INFO, + "configure_fit: strategy sampled %s clients (out of %s)", len(client_instructions), self._client_manager.num_available(), ) @@ -232,11 +233,11 @@ def fit_round( client_instructions=client_instructions, max_workers=self.max_workers, timeout=timeout, + group_id=server_round, ) log( - DEBUG, - "fit_round %s received %s results and %s failures", - server_round, + INFO, + "aggregate_fit: received %s results and %s failures", len(results), len(failures), ) @@ -262,21 +263,25 @@ def disconnect_all_clients(self, timeout: Optional[float]) -> None: timeout=timeout, ) - def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters: + def _get_initial_parameters( + self, server_round: int, timeout: Optional[float] + ) -> Parameters: """Get initial parameters from one of the available clients.""" # Server-side parameter initialization parameters: Optional[Parameters] = self.strategy.initialize_parameters( client_manager=self._client_manager ) if parameters is not None: - log(INFO, "Using initial parameters provided by strategy") + log(INFO, "Using initial global parameters provided by strategy") return parameters # Get initial parameters from one of the clients log(INFO, "Requesting initial parameters from one random client") random_client = self._client_manager.sample(1)[0] ins = GetParametersIns(config={}) - get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout) + get_parameters_res = random_client.get_parameters( + ins=ins, timeout=timeout, group_id=server_round + ) log(INFO, "Received initial parameters from one random client") return get_parameters_res.parameters @@ -319,6 +324,7 @@ def reconnect_client( disconnect = client.reconnect( reconnect, timeout=timeout, + group_id=None, ) 
return client, disconnect @@ -327,11 +333,12 @@ def fit_clients( client_instructions: List[Tuple[ClientProxy, FitIns]], max_workers: Optional[int], timeout: Optional[float], + group_id: int, ) -> FitResultsAndFailures: """Refine parameters concurrently on all selected clients.""" with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: submitted_fs = { - executor.submit(fit_client, client_proxy, ins, timeout) + executor.submit(fit_client, client_proxy, ins, timeout, group_id) for client_proxy, ins in client_instructions } finished_fs, _ = concurrent.futures.wait( @@ -350,10 +357,10 @@ def fit_clients( def fit_client( - client: ClientProxy, ins: FitIns, timeout: Optional[float] + client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int ) -> Tuple[ClientProxy, FitRes]: """Refine parameters on a single client.""" - fit_res = client.fit(ins, timeout=timeout) + fit_res = client.fit(ins, timeout=timeout, group_id=group_id) return client, fit_res @@ -386,11 +393,12 @@ def evaluate_clients( client_instructions: List[Tuple[ClientProxy, EvaluateIns]], max_workers: Optional[int], timeout: Optional[float], + group_id: int, ) -> EvaluateResultsAndFailures: """Evaluate parameters concurrently on all selected clients.""" with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: submitted_fs = { - executor.submit(evaluate_client, client_proxy, ins, timeout) + executor.submit(evaluate_client, client_proxy, ins, timeout, group_id) for client_proxy, ins in client_instructions } finished_fs, _ = concurrent.futures.wait( @@ -412,9 +420,10 @@ def evaluate_client( client: ClientProxy, ins: EvaluateIns, timeout: Optional[float], + group_id: int, ) -> Tuple[ClientProxy, EvaluateRes]: """Evaluate parameters on a single client.""" - evaluate_res = client.evaluate(ins, timeout=timeout) + evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id) return client, evaluate_res @@ -441,3 +450,51 @@ def 
_handle_finished_future_after_evaluate( # Not successful, client returned a result where the status code is not OK failures.append(result) + + +def init_defaults( + server: Optional[Server], + config: Optional[ServerConfig], + strategy: Optional[Strategy], + client_manager: Optional[ClientManager], +) -> Tuple[Server, ServerConfig]: + """Create server instance if none was given.""" + if server is None: + if client_manager is None: + client_manager = SimpleClientManager() + if strategy is None: + strategy = FedAvg() + server = Server(client_manager=client_manager, strategy=strategy) + elif strategy is not None: + log(WARN, "Both server and strategy were provided, ignoring strategy") + + # Set default config values + if config is None: + config = ServerConfig() + + return server, config + + +def run_fl( + server: Server, + config: ServerConfig, +) -> History: + """Train a model on the given server and return the History object.""" + hist, elapsed_time = server.fit( + num_rounds=config.num_rounds, timeout=config.round_timeout + ) + + log(INFO, "") + log(INFO, "[SUMMARY]") + log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time) + for idx, line in enumerate(io.StringIO(str(hist))): + if idx == 0: + log(INFO, "%s", line.strip("\n")) + else: + log(INFO, "\t%s", line.strip("\n")) + log(INFO, "") + + # Graceful shutdown + server.disconnect_all_clients(timeout=config.round_timeout) + + return hist diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py new file mode 100644 index 000000000000..1b2eab87fdaa --- /dev/null +++ b/src/py/flwr/server/server_app.py @@ -0,0 +1,133 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ServerApp.""" + + +from typing import Callable, Optional + +from flwr.common import Context, RecordSet +from flwr.server.strategy import Strategy + +from .client_manager import ClientManager +from .compat import start_driver +from .driver import Driver +from .server import Server +from .server_config import ServerConfig +from .typing import ServerAppCallable + + +class ServerApp: + """Flower ServerApp. + + Examples + -------- + Use the `ServerApp` with an existing `Strategy`: + + >>> server_config = ServerConfig(num_rounds=3) + >>> strategy = FedAvg() + >>> + >>> app = ServerApp( + >>> server_config=server_config, + >>> strategy=strategy, + >>> ) + + Use the `ServerApp` with a custom main function: + + >>> app = ServerApp() + >>> + >>> @app.main() + >>> def main(driver: Driver, context: Context) -> None: + >>> print("ServerApp running") + """ + + def __init__( + self, + server: Optional[Server] = None, + config: Optional[ServerConfig] = None, + strategy: Optional[Strategy] = None, + client_manager: Optional[ClientManager] = None, + ) -> None: + self._server = server + self._config = config + self._strategy = strategy + self._client_manager = client_manager + self._main: Optional[ServerAppCallable] = None + + def __call__(self, driver: Driver, context: Context) -> None: + """Execute `ServerApp`.""" + # Compatibility mode + if not self._main: + start_driver( + server=self._server, + config=self._config, + strategy=self._strategy, 
client_manager=self._client_manager, + driver=driver, + ) + return + + # New execution mode + context = Context(state=RecordSet()) + self._main(driver, context) + + def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]: + """Return a decorator that registers the main fn with the server app. + + Examples + -------- + >>> app = ServerApp() + >>> + >>> @app.main() + >>> def main(driver: Driver, context: Context) -> None: + >>> print("ServerApp running") + """ + + def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable: + """Register the main fn with the ServerApp object.""" + if self._server or self._config or self._strategy or self._client_manager: + raise ValueError( + """Use either a custom main function or a `Strategy`, but not both. + + Use the `ServerApp` with an existing `Strategy`: + + >>> server_config = ServerConfig(num_rounds=3) + >>> strategy = FedAvg() + >>> + >>> app = ServerApp() + >>> server_config=server_config, + >>> strategy=strategy, + >>> ) + + Use the `ServerApp` with a custom main function: + + >>> app = ServerApp() + >>> + >>> @app.main() + >>> def main(driver: Driver, context: Context) -> None: + >>> print("ServerApp running") + """, + ) + + # Register provided function with the ServerApp object + self._main = main_fn + + # Return provided function unmodified + return main_fn + + return main_decorator + + +class LoadServerAppError(Exception): + """Error when trying to load `ServerApp`.""" diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py new file mode 100644 index 000000000000..38c0d6240d90 --- /dev/null +++ b/src/py/flwr/server/server_app_test.py @@ -0,0 +1,62 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ServerApp.""" + + +from unittest.mock import MagicMock + +import pytest + +from flwr.common import Context, RecordSet +from flwr.server import ServerApp, ServerConfig +from flwr.server.driver import Driver + + +def test_server_app_custom_mode() -> None: + """Test ServerApp in custom-main mode.""" + # Prepare + app = ServerApp() + driver = MagicMock() + context = Context(state=RecordSet()) + + called = {"called": False} + + # pylint: disable=unused-argument + @app.main() + def custom_main(driver: Driver, context: Context) -> None: + called["called"] = True + + # pylint: enable=unused-argument + + # Execute + app(driver, context) + + # Assert + assert called["called"] + + +def test_server_app_exception_when_both_modes() -> None: + """Test ServerApp error when both compat mode and custom fns are used.""" + # Prepare + app = ServerApp(config=ServerConfig(num_rounds=3)) + + # Execute and assert + with pytest.raises(ValueError): + # pylint: disable=unused-argument + @app.main() + def custom_main(driver: Driver, context: Context) -> None: + pass + + # pylint: enable=unused-argument diff --git a/src/py/flwr/server/server_config.py b/src/py/flwr/server/server_config.py new file mode 100644 index 000000000000..c47367eab4c0 --- /dev/null +++ b/src/py/flwr/server/server_config.py @@ -0,0 +1,40 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ServerConfig.""" + + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ServerConfig: + """Flower server config. + + All attributes have default values which allows users to configure just the ones + they care about. + """ + + num_rounds: int = 1 + round_timeout: Optional[float] = None + + def __repr__(self) -> str: + """Return the string representation of the ServerConfig.""" + timeout_string = ( + "no round_timeout" + if self.round_timeout is None + else f"round_timeout={self.round_timeout}s" + ) + return f"num_rounds={self.num_rounds}, {timeout_string}" diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 63ec1021ff5c..274e5289fee1 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -45,18 +45,20 @@ class SuccessClient(ClientProxy): """Test class.""" def get_properties( - self, ins: GetPropertiesIns, timeout: Optional[float] + self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int] ) -> GetPropertiesRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() def get_parameters( - self, ins: GetParametersIns, timeout: Optional[float] + self, ins: 
GetParametersIns, timeout: Optional[float], group_id: Optional[int] ) -> GetParametersRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() - def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: + def fit( + self, ins: FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" arr = np.array([[1, 2], [3, 4], [5, 6]]) arr_serialized = ndarray_to_bytes(arr) @@ -67,7 +69,9 @@ def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: metrics={}, ) - def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: + def evaluate( + self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> EvaluateRes: """Simulate evaluate by returning a success EvaluateRes with loss 1.0.""" return EvaluateRes( status=Status(code=Code.OK, message="Success"), @@ -76,7 +80,9 @@ def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: metrics={}, ) - def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: + def reconnect( + self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int] + ) -> DisconnectRes: """Simulate reconnect by returning a DisconnectRes with UNKNOWN reason.""" return DisconnectRes(reason="UNKNOWN") @@ -85,28 +91,34 @@ class FailingClient(ClientProxy): """Test class.""" def get_properties( - self, ins: GetPropertiesIns, timeout: Optional[float] + self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int] ) -> GetPropertiesRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def get_parameters( - self, ins: GetParametersIns, timeout: Optional[float] + self, ins: 
GetParametersIns, timeout: Optional[float], group_id: Optional[int] ) -> GetParametersRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() - def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + def fit( + self, ins: FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> FitRes: + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() - def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + def evaluate( + self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> EvaluateRes: + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() - def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + def reconnect( + self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int] + ) -> DisconnectRes: + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def test_fit_clients() -> None: @@ -122,7 +134,7 @@ def test_fit_clients() -> None: client_instructions = [(c, ins) for c in clients] # Execute - results, failures = fit_clients(client_instructions, None, None) + results, failures = fit_clients(client_instructions, None, None, 0) # Assert assert len(results) == 1 @@ -150,6 +162,7 @@ def test_eval_clients() -> None: client_instructions=client_instructions, max_workers=None, timeout=None, + group_id=0, ) # Assert diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 1750a7522379..b7de9a946fff 100644 --- 
a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -16,6 +16,18 @@ from .bulyan import Bulyan as Bulyan +from .dp_adaptive_clipping import ( + DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping, +) +from .dp_adaptive_clipping import ( + DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping, +) +from .dp_fixed_clipping import ( + DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping, +) +from .dp_fixed_clipping import ( + DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping, +) from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg @@ -37,24 +49,28 @@ from .strategy import Strategy as Strategy __all__ = [ - "FaultTolerantFedAvg", + "Bulyan", + "DPFedAvgAdaptive", + "DPFedAvgFixed", + "DifferentialPrivacyClientSideAdaptiveClipping", + "DifferentialPrivacyServerSideAdaptiveClipping", + "DifferentialPrivacyClientSideFixedClipping", + "DifferentialPrivacyServerSideFixedClipping", "FedAdagrad", "FedAdam", "FedAvg", - "FedXgbNnAvg", - "FedXgbBagging", - "FedXgbCyclic", "FedAvgAndroid", "FedAvgM", + "FedMedian", "FedOpt", "FedProx", - "FedYogi", - "QFedAvg", - "FedMedian", "FedTrimmedAvg", + "FedXgbBagging", + "FedXgbCyclic", + "FedXgbNnAvg", + "FedYogi", + "FaultTolerantFedAvg", "Krum", - "Bulyan", - "DPFedAvgAdaptive", - "DPFedAvgFixed", + "QFedAvg", "Strategy", ] diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 63926f2eaa51..c668b55eebe6 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -20,13 +20,14 @@ import numpy as np -from flwr.common import NDArray, NDArrays +from flwr.common import FitRes, NDArray, NDArrays, 
parameters_to_ndarrays +from flwr.server.client_proxy import ClientProxy def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training - num_examples_total = sum([num_examples for _, num_examples in results]) + num_examples_total = sum(num_examples for (_, num_examples) in results) # Create a list of weights, each multiplied by the related number of examples weighted_weights = [ @@ -41,6 +42,31 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime +def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: + """Compute in-place weighted average.""" + # Count total examples + num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) + + # Compute scaling factors for each result + scaling_factors = [ + fit_res.num_examples / num_examples_total for _, fit_res in results + ] + + # Let's do in-place aggregation + # Get first result, then add up each other + params = [ + scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + ] + for i, (_, fit_res) in enumerate(results[1:]): + res = ( + scaling_factors[i + 1] * x + for x in parameters_to_ndarrays(fit_res.parameters) + ) + params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + + return params + + def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples @@ -69,9 +95,9 @@ def aggregate_krum( # For each client, take the n-f-2 closest parameters vectors num_closest = max(1, len(weights) - num_malicious - 2) closest_indices = [] - for i, _ in enumerate(distance_matrix): + for distance in distance_matrix: closest_indices.append( - np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203 + np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203 ) # Compute the score for each client, that is 
the sum of the distances @@ -176,7 +202,7 @@ def aggregate_bulyan( def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: """Aggregate evaluation results obtained from multiple clients.""" - num_total_evaluation_examples = sum([num_examples for num_examples, _ in results]) + num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] return sum(weighted_losses) / num_total_evaluation_examples @@ -207,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray: """ flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights]) distance_matrix = np.zeros((len(weights), len(weights))) - for i, _ in enumerate(flat_w): - for j, _ in enumerate(flat_w): - delta = flat_w[i] - flat_w[j] + for i, flat_w_i in enumerate(flat_w): + for j, flat_w_j in enumerate(flat_w): + delta = flat_w_i - flat_w_j norm = np.linalg.norm(delta) distance_matrix[i, j] = norm**2 return distance_matrix diff --git a/src/py/flwr/server/strategy/dp_adaptive_clipping.py b/src/py/flwr/server/strategy/dp_adaptive_clipping.py new file mode 100644 index 000000000000..1acfd4613a0a --- /dev/null +++ b/src/py/flwr/server/strategy/dp_adaptive_clipping.py @@ -0,0 +1,470 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Central differential privacy with adaptive clipping. + +Paper (Andrew et al.): https://arxiv.org/abs/1905.03871 +""" + + +import math +from logging import INFO, WARNING +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.differential_privacy import ( + adaptive_clip_inputs_inplace, + add_gaussian_noise_to_params, + compute_adaptive_noise_params, + compute_stdv, +) +from flwr.common.differential_privacy_constants import ( + CLIENTS_DISCREPANCY_WARNING, + KEY_CLIPPING_NORM, + KEY_NORM_BIT, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.strategy import Strategy + + +class DifferentialPrivacyServerSideAdaptiveClipping(Strategy): + """Strategy wrapper for central DP with server-side adaptive clipping. + + Parameters + ---------- + strategy: Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + num_sampled_clients : int + The number of clients that are sampled on each round. + initial_clipping_norm : float + The initial value of clipping norm. Defaults to 0.1. + Andrew et al. recommends to set to 0.1. + target_clipped_quantile : float + The desired quantile of updates which should be clipped. Defaults to 0.5. + clip_norm_lr : float + The learning rate for the clipping norm adaptation. Defaults to 0.2. + Andrew et al. recommends to set to 0.2. + clipped_count_stddev : float + The standard deviation of the noise added to the count of updates below the estimate. + Andrew et al. 
recommends to set to `expected_num_records/20` + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg( ... ) + + Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper + + >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping( + >>> strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ... + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + num_sampled_clients: int, + initial_clipping_norm: float = 0.1, + target_clipped_quantile: float = 0.5, + clip_norm_lr: float = 0.2, + clipped_count_stddev: Optional[float] = None, + ) -> None: + super().__init__() + + if strategy is None: + raise ValueError("The passed strategy is None.") + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + if initial_clipping_norm <= 0: + raise ValueError("The initial clipping norm should be a positive value.") + + if not 0 <= target_clipped_quantile <= 1: + raise ValueError( + "The target clipped quantile must be between 0 and 1 (inclusive)." 
+ ) + + if clip_norm_lr <= 0: + raise ValueError("The learning rate must be positive.") + + if clipped_count_stddev is not None: + if clipped_count_stddev < 0: + raise ValueError("The `clipped_count_stddev` must be non-negative.") + + self.strategy = strategy + self.num_sampled_clients = num_sampled_clients + self.clipping_norm = initial_clipping_norm + self.target_clipped_quantile = target_clipped_quantile + self.clip_norm_lr = clip_norm_lr + ( + self.clipped_count_stddev, + self.noise_multiplier, + ) = compute_adaptive_noise_params( + noise_multiplier, + num_sampled_clients, + clipped_count_stddev, + ) + + self.current_round_params: NDArrays = [] + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + self.current_round_params = parameters_to_ndarrays(parameters) + return self.strategy.configure_fit(server_round, parameters, client_manager) + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate training results and update clip norms.""" + if failures: + return None, 
{} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + norm_bit_set_count = 0 + for _, res in results: + param = parameters_to_ndarrays(res.parameters) + # Compute and clip update + model_update = [ + np.subtract(x, y) for (x, y) in zip(param, self.current_round_params) + ] + + norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm) + norm_bit_set_count += norm_bit + + log( + INFO, + "aggregate_fit: parameters are clipped by value: %s.", + self.clipping_norm, + ) + + for i, _ in enumerate(self.current_round_params): + param[i] = self.current_round_params[i] + model_update[i] + # Convert back to parameters + res.parameters = ndarrays_to_parameters(param) + + # Noising the count + noised_norm_bit_set_count = float( + np.random.normal(norm_bit_set_count, self.clipped_count_stddev) + ) + noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results) + # Geometric update + self.clipping_norm *= math.exp( + -self.clip_norm_lr + * (noised_norm_bit_set_fraction - self.target_clipped_quantile) + ) + + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + log( + INFO, + "aggregate_fit: central DP noise with standard deviation: %s added to parameters.", + compute_stdv( + self.noise_multiplier, self.clipping_norm, self.num_sampled_clients + ), + ) + + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return 
self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) + + +class DifferentialPrivacyClientSideAdaptiveClipping(Strategy): + """Strategy wrapper for central DP with client-side adaptive clipping. + + Use `adaptiveclipping_mod` modifier at the client side. + + In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`, + which performs clipping on the server-side, `DifferentialPrivacyClientSideAdaptiveClipping` + expects clipping to happen on the client-side, usually by using the built-in + `adaptiveclipping_mod`. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + num_sampled_clients : int + The number of clients that are sampled on each round. + initial_clipping_norm : float + The initial value of clipping norm. Defaults to 0.1. + Andrew et al. recommends to set to 0.1. + target_clipped_quantile : float + The desired quantile of updates which should be clipped. Defaults to 0.5. + clip_norm_lr : float + The learning rate for the clipping norm adaptation. Defaults to 0.2. + Andrew et al. recommends to set to 0.2. + clipped_count_stddev : float + The stddev of the noise added to the count of updates currently below the estimate. + Andrew et al. recommends to set to `expected_num_records/20` + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg(...) 
+ + Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper: + + >>> dp_strategy = DifferentialPrivacyClientSideAdaptiveClipping( + >>> strategy, cfg.noise_multiplier, cfg.num_sampled_clients + >>> ) + + On the client, add the `adaptiveclipping_mod` to the client-side mods: + + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[adaptiveclipping_mod] + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + num_sampled_clients: int, + initial_clipping_norm: float = 0.1, + target_clipped_quantile: float = 0.5, + clip_norm_lr: float = 0.2, + clipped_count_stddev: Optional[float] = None, + ) -> None: + super().__init__() + + if strategy is None: + raise ValueError("The passed strategy is None.") + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + if initial_clipping_norm <= 0: + raise ValueError("The initial clipping norm should be a positive value.") + + if not 0 <= target_clipped_quantile <= 1: + raise ValueError( + "The target clipped quantile must be between 0 and 1 (inclusive)." 
+ ) + + if clip_norm_lr <= 0: + raise ValueError("The learning rate must be positive.") + + if clipped_count_stddev is not None and clipped_count_stddev < 0: + raise ValueError("The `clipped_count_stddev` must be non-negative.") + + self.strategy = strategy + self.num_sampled_clients = num_sampled_clients + self.clipping_norm = initial_clipping_norm + self.target_clipped_quantile = target_clipped_quantile + self.clip_norm_lr = clip_norm_lr + ( + self.clipped_count_stddev, + self.noise_multiplier, + ) = compute_adaptive_noise_params( + noise_multiplier, + num_sampled_clients, + clipped_count_stddev, + ) + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> 
Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate training results and update clip norms.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + self._update_clip_norm(results) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + log( + INFO, + "aggregate_fit: central DP noise with standard deviation: %s added to parameters.", + compute_stdv( + self.noise_multiplier, self.clipping_norm, self.num_sampled_clients + ), + ) + + return aggregated_params, metrics + + def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + # Calculate the number of clients which set the norm indicator bit + norm_bit_set_count = 0 + for client_proxy, fit_res in results: + if KEY_NORM_BIT not in fit_res.metrics: + raise KeyError( + f"{KEY_NORM_BIT} not returned by client with id {client_proxy.cid}." 
+ ) + if fit_res.metrics[KEY_NORM_BIT]: + norm_bit_set_count += 1 + # Add noise to the count + noised_norm_bit_set_count = float( + np.random.normal(norm_bit_set_count, self.clipped_count_stddev) + ) + + noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results) + # Geometric update + self.clipping_norm *= math.exp( + -self.clip_norm_lr + * (noised_norm_bit_set_fraction - self.target_clipped_quantile) + ) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dp_fixed_clipping.py b/src/py/flwr/server/strategy/dp_fixed_clipping.py new file mode 100644 index 000000000000..61e8123e28d7 --- /dev/null +++ b/src/py/flwr/server/strategy/dp_fixed_clipping.py @@ -0,0 +1,360 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Central differential privacy with fixed clipping. + +Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963 +""" + + +from logging import INFO, WARNING +from typing import Dict, List, Optional, Tuple, Union + +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.differential_privacy import ( + add_gaussian_noise_to_params, + compute_clip_model_update, + compute_stdv, +) +from flwr.common.differential_privacy_constants import ( + CLIENTS_DISCREPANCY_WARNING, + KEY_CLIPPING_NORM, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.strategy import Strategy + + +class DifferentialPrivacyServerSideFixedClipping(Strategy): + """Strategy wrapper for central DP with server-side fixed clipping. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg( ... 
) + + Wrap the strategy with the DifferentialPrivacyServerSideFixedClipping wrapper + + >>> dp_strategy = DifferentialPrivacyServerSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping norm should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + self.current_round_params: NDArrays = [] + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + self.current_round_params = parameters_to_ndarrays(parameters) + return self.strategy.configure_fit(server_round, parameters, client_manager) + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, 
parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Compute the updates, clip, and pass them for aggregation. + + Afterward, add noise to the aggregated parameters. + """ + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + for _, res in results: + param = parameters_to_ndarrays(res.parameters) + # Compute and clip update + compute_clip_model_update( + param, self.current_round_params, self.clipping_norm + ) + log( + INFO, + "aggregate_fit: parameters are clipped by value: %s.", + self.clipping_norm, + ) + # Convert back to parameters + res.parameters = ndarrays_to_parameters(param) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + + log( + INFO, + "aggregate_fit: central DP noise with standard deviation: %s added to parameters.", + compute_stdv( + self.noise_multiplier, self.clipping_norm, self.num_sampled_clients + ), + ) + + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, 
Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) + + +class DifferentialPrivacyClientSideFixedClipping(Strategy): + """Strategy wrapper for central DP with client-side fixed clipping. + + Use `fixedclipping_mod` modifier at the client side. + + In comparison to `DifferentialPrivacyServerSideFixedClipping`, + which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping` + expects clipping to happen on the client-side, usually by using the built-in + `fixedclipping_mod`. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg(...) 
+ + Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper: + + >>> dp_strategy = DifferentialPrivacyClientSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + + On the client, add the `fixedclipping_mod` to the client-side mods: + + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[fixedclipping_mod] + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping threshold should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return 
inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Add noise to the aggregated parameters.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + log( + INFO, + "aggregate_fit: central DP noise with standard deviation: %s added to parameters.", + compute_stdv( + self.noise_multiplier, self.clipping_norm, self.num_sampled_clients + ), + ) + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) 
diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 3269735e9d73..a908679ed668 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -24,6 +24,7 @@ import numpy as np from flwr.common import FitIns, FitRes, Parameters, Scalar +from flwr.common.logger import warn_deprecated_feature from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.strategy.dpfedavg_fixed import DPFedAvgFixed @@ -31,7 +32,12 @@ class DPFedAvgAdaptive(DPFedAvgFixed): - """Wrapper for configuring a Strategy for DP with Adaptive Clipping.""" + """Wrapper for configuring a Strategy for DP with Adaptive Clipping. + + Warning + ------- + This class is deprecated and will be removed in a future release. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( @@ -45,6 +51,7 @@ def __init__( clip_norm_target_quantile: float = 0.5, clip_count_stddev: Optional[float] = None, ) -> None: + warn_deprecated_feature("`DPFedAvgAdaptive` wrapper") super().__init__( strategy=strategy, num_sampled_clients=num_sampled_clients, @@ -91,7 +98,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: norm_bit_set_count = 0 for client_proxy, fit_res in results: if "dpfedavg_norm_bit" not in fit_res.metrics: - raise Exception( + raise KeyError( f"Indicator bit not returned by client with id {client_proxy.cid}." 
) if fit_res.metrics["dpfedavg_norm_bit"]: diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index 0154cfd79fc5..c54379fc7087 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -17,11 +17,11 @@ Paper: arxiv.org/pdf/1710.06963.pdf """ - from typing import Dict, List, Optional, Tuple, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.dp import add_gaussian_noise +from flwr.common.logger import warn_deprecated_feature from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy @@ -29,7 +29,12 @@ class DPFedAvgFixed(Strategy): - """Wrapper for configuring a Strategy for DP with Fixed Clipping.""" + """Wrapper for configuring a Strategy for DP with Fixed Clipping. + + Warning + ------- + This class is deprecated and will be removed in a future release. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( @@ -40,17 +45,18 @@ def __init__( noise_multiplier: float = 1, server_side_noising: bool = True, ) -> None: + warn_deprecated_feature("`DPFedAvgFixed` wrapper") super().__init__() self.strategy = strategy # Doing fixed-size subsampling as in https://arxiv.org/abs/1905.03871. 
self.num_sampled_clients = num_sampled_clients if clip_norm <= 0: - raise Exception("The clipping threshold should be a positive value.") + raise ValueError("The clipping threshold should be a positive value.") self.clip_norm = clip_norm if noise_multiplier < 0: - raise Exception("The noise multiplier should be a non-negative value.") + raise ValueError("The noise multiplier should be a non-negative value.") self.noise_multiplier = noise_multiplier self.server_side_noising = server_side_noising @@ -98,9 +104,9 @@ def configure_fit( """ additional_config = {"dpfedavg_clip_norm": self.clip_norm} if not self.server_side_noising: - additional_config[ - "dpfedavg_noise_stddev" - ] = self._calc_client_noise_stddev() + additional_config["dpfedavg_noise_stddev"] = ( + self._calc_client_noise_stddev() + ) client_instructions = self.strategy.configure_fit( server_round, parameters, client_manager diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index b3380a5be2f9..0c966442ecaf 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .fedadagrad import FedAdagrad diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index c93c8cb8b83e..3b9b2640c2b5 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -37,7 +37,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy -from .aggregate import aggregate, weighted_loss_avg +from .aggregate import aggregate, aggregate_inplace, weighted_loss_avg from .strategy import Strategy WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ @@ -84,6 +84,8 @@ 
class FedAvg(Strategy): Metrics aggregation function, optional. evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] Metrics aggregation function, optional. + inplace : bool (default: True) + Enable (True) or disable (False) in-place aggregation of model updates. """ # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long @@ -107,6 +109,7 @@ def __init__( initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + inplace: bool = True, ) -> None: super().__init__() @@ -128,6 +131,7 @@ def __init__( self.initial_parameters = initial_parameters self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn + self.inplace = inplace def __repr__(self) -> str: """Compute a string representation of the strategy.""" @@ -226,12 +230,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - for _, fit_res in results - ] - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + if self.inplace: + # Does in-place weighted average of results + aggregated_ndarrays = aggregate_inplace(results) + else: + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + aggregated_ndarrays = aggregate(weights_results) + + parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays) # Aggregate custom metrics if aggregation fn was provided metrics_aggregated = {} diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index e890f7216020..6678b7ced114 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ 
-234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays: """Convert parameters object to NumPy weights.""" return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors] - # pylint: disable=R0201 def ndarray_to_bytes(self, ndarray: NDArray) -> bytes: """Serialize NumPy array to bytes.""" return ndarray.tobytes() - # pylint: disable=R0201 def bytes_to_ndarray(self, tensor: bytes) -> NDArray: """Deserialize NumPy array from bytes.""" ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32) diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index 947736f4a571..e62eaa5c5832 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,6 +15,16 @@ """FedAvg tests.""" +from typing import List, Tuple, Union +from unittest.mock import MagicMock + +import numpy as np +from numpy.testing import assert_allclose + +from flwr.common import Code, FitRes, Status, parameters_to_ndarrays +from flwr.common.parameter import ndarrays_to_parameters +from flwr.server.client_proxy import ClientProxy + from .fedavg import FedAvg @@ -120,3 +130,51 @@ def test_fedavg_num_evaluation_clients_minimum() -> None: # Assert assert expected == actual + + +def test_inplace_aggregate_fit_equivalence() -> None: + """Test aggregate_fit equivalence between FedAvg and its inplace version.""" + # Prepare + weights0_0 = np.random.randn(100, 64) + weights0_1 = np.random.randn(314, 628, 3) + weights1_0 = np.random.randn(100, 64) + weights1_1 = np.random.randn(314, 628, 3) + + results: List[Tuple[ClientProxy, FitRes]] = [ + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights0_0, weights0_1]), + num_examples=1, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights1_0, weights1_1]), + num_examples=5, + 
metrics={}, + ), + ), + ] + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + + fedavg_reference = FedAvg(inplace=False) + fedavg_inplace = FedAvg() + + # Execute + reference, _ = fedavg_reference.aggregate_fit(1, results, failures) + assert reference + inplace, _ = fedavg_inplace.aggregate_fit(1, results, failures) + assert inplace + + # Convert to NumPy to check similarity + reference_np = parameters_to_ndarrays(reference) + inplace_np = parameters_to_ndarrays(inplace) + + # Assert + for ref, inp in zip(reference_np, inplace_np): + assert_allclose(ref, inp) diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index 7a5bf1425b44..17e979d92beb 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -36,7 +36,7 @@ class FedMedian(FedAvg): - """Configurable FedAvg with Momentum strategy implementation.""" + """Configurable FedMedian strategy implementation.""" def __repr__(self) -> str: """Compute a string representation of the strategy.""" diff --git a/src/py/flwr/server/strategy/fedmedian_test.py b/src/py/flwr/server/strategy/fedmedian_test.py index 180503df6c80..57cf08d8c01d 100644 --- a/src/py/flwr/server/strategy/fedmedian_test.py +++ b/src/py/flwr/server/strategy/fedmedian_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .fedmedian import FedMedian diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index cafb466c2e8b..a8e8adddafbb 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -44,6 +44,11 @@ def __init__( self.global_model: Optional[bytes] = None super().__init__(**kwargs) + def __repr__(self) -> str: + """Compute 
a string representation of the strategy.""" + rep = f"FedXgbBagging(accept_failures={self.accept_failures})" + return rep + def aggregate_fit( self, server_round: int, diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py index e2707b02d19d..2605daab29f4 100644 --- a/src/py/flwr/server/strategy/fedxgb_cyclic.py +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -37,6 +37,11 @@ def __init__( self.global_model: Optional[bytes] = None super().__init__(**kwargs) + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedXgbCyclic(accept_failures={self.accept_failures})" + return rep + def aggregate_fit( self, server_round: int, diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index f300633d0d9f..8dedc925f350 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays -from flwr.common.logger import log +from flwr.common.logger import log, warn_deprecated_feature from flwr.server.client_proxy import ClientProxy from .aggregate import aggregate @@ -33,7 +33,13 @@ class FedXgbNnAvg(FedAvg): - """Configurable FedXgbNnAvg strategy implementation.""" + """Configurable FedXgbNnAvg strategy implementation. + + Warning + ------- + This strategy is deprecated, but a copy of it is available in Flower Baselines: + https://github.com/adap/flower/tree/main/baselines/hfedxgboost. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: """Federated XGBoost [Ma et al., 2023] strategy. @@ -41,6 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: Implementation based on https://arxiv.org/abs/2304.07537. 
""" super().__init__(*args, **kwargs) + warn_deprecated_feature("`FedXgbNnAvg` strategy") def __repr__(self) -> str: """Compute a string representation of the strategy.""" diff --git a/src/py/flwr/server/strategy/krum_test.py b/src/py/flwr/server/strategy/krum_test.py index 81e59230739a..653dc9a8475d 100644 --- a/src/py/flwr/server/strategy/krum_test.py +++ b/src/py/flwr/server/strategy/krum_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .krum import Krum diff --git a/src/py/flwr/server/strategy/multikrum_test.py b/src/py/flwr/server/strategy/multikrum_test.py index 1469db104252..f874dc2f9800 100644 --- a/src/py/flwr/server/strategy/multikrum_test.py +++ b/src/py/flwr/server/strategy/multikrum_test.py @@ -30,7 +30,7 @@ parameters_to_ndarrays, ) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy from .krum import Krum diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 94a67fbcbfae..758e8e608e9f 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float: hs_ffl = [] if self.pre_weights is None: - raise Exception("QffedAvg pre_weights are None in aggregate_fit") + raise AttributeError("QffedAvg pre_weights are None in aggregate_fit") weights_before = self.pre_weights eval_result = self.evaluate( diff --git a/src/py/flwr/server/superlink/__init__.py b/src/py/flwr/server/superlink/__init__.py new file mode 100644 index 000000000000..94102100de26 --- /dev/null +++ b/src/py/flwr/server/superlink/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs 
GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower SuperLink.""" diff --git a/dev/publish.sh b/src/py/flwr/server/superlink/driver/__init__.py old mode 100755 new mode 100644 similarity index 84% rename from dev/publish.sh rename to src/py/flwr/server/superlink/driver/__init__.py index fb4df1694530..2bfe63e6065f --- a/dev/publish.sh +++ b/src/py/flwr/server/superlink/driver/__init__.py @@ -1,5 +1,3 @@ -#!/bin/bash - # Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ - -python -m poetry publish +"""Flower driver service.""" diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py new file mode 100644 index 000000000000..f74000bc59c4 --- /dev/null +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -0,0 +1,54 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Driver gRPC API.""" + +from logging import INFO +from typing import Optional, Tuple + +import grpc + +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 + add_DriverServicer_to_server, +) +from flwr.server.superlink.state import StateFactory + +from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server +from .driver_servicer import DriverServicer + + +def run_driver_api_grpc( + address: str, + state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], +) -> grpc.Server: + """Run Driver API (gRPC, request-response).""" + # Create Driver API gRPC server + driver_servicer: grpc.Server = DriverServicer( + state_factory=state_factory, + ) + driver_add_servicer_to_server_fn = add_DriverServicer_to_server + driver_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + certificates=certificates, + ) + + log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address) + driver_grpc_server.start() + + return driver_grpc_server diff --git a/src/py/flwr/server/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py similarity index 79% rename from src/py/flwr/server/driver/driver_servicer.py rename to src/py/flwr/server/superlink/driver/driver_servicer.py index f96b3b1262ac..59e51ef52d8e 
100644 --- a/src/py/flwr/server/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -15,17 +15,17 @@ """Driver API servicer.""" -from logging import INFO +from logging import DEBUG, INFO from typing import List, Optional, Set from uuid import UUID import grpc from flwr.common.logger import log -from flwr.proto import driver_pb2_grpc -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - CreateWorkloadResponse, +from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -33,9 +33,9 @@ PushTaskInsRequest, PushTaskInsResponse, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskRes -from flwr.server.state import State, StateFactory +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -49,28 +49,28 @@ def GetNodes( self, request: GetNodesRequest, context: grpc.ServicerContext ) -> GetNodesResponse: """Get available nodes.""" - log(INFO, "DriverServicer.GetNodes") + log(DEBUG, "DriverServicer.GetNodes") state: State = self.state_factory.state() - all_ids: Set[int] = state.get_nodes(request.workload_id) + all_ids: Set[int] = state.get_nodes(request.run_id) nodes: List[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids ] return GetNodesResponse(nodes=nodes) - def CreateWorkload( - self, request: CreateWorkloadRequest, context: grpc.ServicerContext - ) -> CreateWorkloadResponse: - """Create workload ID.""" - log(INFO, "DriverServicer.CreateWorkload") + def CreateRun( + self, request: CreateRunRequest, context: grpc.ServicerContext + ) -> CreateRunResponse: + """Create run ID.""" + log(INFO, "DriverServicer.CreateRun") 
state: State = self.state_factory.state() - workload_id = state.create_workload() - return CreateWorkloadResponse(workload_id=workload_id) + run_id = state.create_run() + return CreateRunResponse(run_id=run_id) def PushTaskIns( self, request: PushTaskInsRequest, context: grpc.ServicerContext ) -> PushTaskInsResponse: """Push a set of TaskIns.""" - log(INFO, "DriverServicer.PushTaskIns") + log(DEBUG, "DriverServicer.PushTaskIns") # Validate request _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty") @@ -95,7 +95,7 @@ def PullTaskRes( self, request: PullTaskResRequest, context: grpc.ServicerContext ) -> PullTaskResResponse: """Pull a set of TaskRes.""" - log(INFO, "DriverServicer.PullTaskRes") + log(DEBUG, "DriverServicer.PullTaskRes") # Convert each task_id str to UUID task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids} @@ -105,7 +105,7 @@ def PullTaskRes( # Register callback def on_rpc_done() -> None: - log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes") + log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes") if context.is_active(): return diff --git a/src/py/flwr/server/driver/driver_servicer_test.py b/src/py/flwr/server/superlink/driver/driver_servicer_test.py similarity index 95% rename from src/py/flwr/server/driver/driver_servicer_test.py rename to src/py/flwr/server/superlink/driver/driver_servicer_test.py index c432c026a632..99f7cc007a89 100644 --- a/src/py/flwr/server/driver/driver_servicer_test.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer_test.py @@ -15,7 +15,7 @@ """DriverServicer tests.""" -from flwr.server.driver.driver_servicer import _raise_if +from flwr.server.superlink.driver.driver_servicer import _raise_if # pylint: disable=broad-except diff --git a/src/py/flwr/server/fleet/__init__.py b/src/py/flwr/server/superlink/fleet/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/__init__.py rename to 
src/py/flwr/server/superlink/fleet/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_bidi/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/grpc_bidi/__init__.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py similarity index 88% rename from src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py index 1f7a8e9259fc..6f94ea844e38 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py @@ -18,17 +18,24 @@ - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md """ - +import uuid from typing import Callable, Iterator import grpc from iterators import TimeoutIterator -from flwr.proto import transport_pb2_grpc -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto import transport_pb2_grpc # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_manager import ClientManager -from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( + GrpcBridge, + InsWrapper, + ResWrapper, +) +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy def default_bridge_factory() -> GrpcBridge: @@ -88,9 +95,12 @@ def Join( # pylint: disable=invalid-name wrapping the actual message - The `Join` method is (pretty much) unaware of the protocol """ - peer: str = context.peer() + # When running Flower behind a proxy, the peer can be the same for 
+ # different clients, so instead of `cid: str = context.peer()` we + # use a `UUID4` that is unique. + cid: str = uuid.uuid4().hex bridge = self.grpc_bridge_factory() - client_proxy = self.client_proxy_factory(peer, bridge) + client_proxy = self.client_proxy_factory(cid, bridge) is_success = register_client_proxy(self.client_manager, client_proxy, context) if is_success: diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py similarity index 89% rename from src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py index 64140ed274c9..bd93554a6a32 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py @@ -16,18 +16,23 @@ import unittest +import uuid from unittest.mock import MagicMock, call -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.server.fleet.grpc_bidi.flower_service_servicer import ( +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) +from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import ( FlowerServiceServicer, register_client_proxy, ) -from flwr.server.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() -CLIENT_CID = "some_client_cid" + +CID: str = uuid.uuid4().hex class FlowerServiceServicerTestCase(unittest.TestCase): @@ -39,7 +44,6 @@ def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() - self.context_mock.peer.return_value = CLIENT_CID # Define client_messages to be processed by FlowerServiceServicer instance 
self.client_messages = [CLIENT_MESSAGE for _ in range(5)] @@ -67,7 +71,7 @@ def setUp(self) -> None: # Create a GrpcClientProxy mock which we will use to test if correct # methods where called and client_messages are getting passed to it self.grpc_client_proxy_mock = MagicMock() - self.grpc_client_proxy_mock.cid = CLIENT_CID + self.grpc_client_proxy_mock.cid = CID self.client_proxy_factory_mock = MagicMock() self.client_proxy_factory_mock.return_value = self.grpc_client_proxy_mock @@ -124,11 +128,7 @@ def test_join(self) -> None: num_server_messages += 1 assert len(self.client_messages) == num_server_messages - assert self.grpc_client_proxy_mock.cid == CLIENT_CID - - self.client_proxy_factory_mock.assert_called_once_with( - CLIENT_CID, self.grpc_bridge_mock - ) + assert self.grpc_client_proxy_mock.cid == CID # Check if the client was registered with the client_manager self.client_manager_mock.register.assert_called_once_with( diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py similarity index 93% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py index 6ae38ea3d805..d5b4a915c609 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py @@ -20,7 +20,10 @@ from threading import Condition from typing import Iterator, Optional -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) @dataclass @@ -113,7 +116,7 @@ def _transition(self, next_status: Status) -> None: ): self._status = next_status else: - raise Exception(f"Invalid transition: {self._status} to {next_status}") + raise ValueError(f"Invalid transition: {self._status} to {next_status}") self._cv.notify_all() @@ -129,7 +132,7 @@ def request(self, ins_wrapper: InsWrapper) -> 
ResWrapper: self._raise_if_closed() if self._status != Status.AWAITING_INS_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._ins_wrapper = ins_wrapper # Write self._transition(Status.INS_WRAPPER_AVAILABLE) @@ -146,7 +149,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._transition(Status.AWAITING_INS_WRAPPER) if res_wrapper is None: - raise Exception("ResWrapper can not be None") + raise ValueError("ResWrapper can not be None") return res_wrapper @@ -170,7 +173,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]: self._transition(Status.AWAITING_RES_WRAPPER) if ins_wrapper is None: - raise Exception("InsWrapper can not be None") + raise ValueError("InsWrapper can not be None") yield ins_wrapper @@ -180,7 +183,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None: self._raise_if_closed() if self._status != Status.AWAITING_RES_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._res_wrapper = res_wrapper # Write self._transition(Status.RES_WRAPPER_AVAILABLE) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py similarity index 95% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py index 18a2144072ed..f7c236acd7a1 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py @@ -19,8 +19,11 @@ from threading import Thread from typing import List, Union -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.server.fleet.grpc_bidi.grpc_bridge import ( +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( GrpcBridge, GrpcBridgeClosed, InsWrapper, @@ 
-70,6 +73,7 @@ def test_workflow_successful() -> None: _ = next(ins_wrapper_iterator) bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage())) except Exception as exception: + # pylint: disable-next=broad-exception-raised raise Exception from exception # Wait until worker_thread is finished diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py similarity index 92% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py index b9bc7330db31..ac62ad014950 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py @@ -19,9 +19,16 @@ from flwr import common from flwr.common import serde -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_proxy import ClientProxy -from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ( + GrpcBridge, + InsWrapper, + ResWrapper, +) class GrpcClientProxy(ClientProxy): @@ -39,6 +46,7 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Request client's set of internal properties.""" get_properties_msg = serde.get_properties_ins_to_proto(ins) @@ -58,6 +66,7 @@ def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" get_parameters_msg = serde.get_parameters_ins_to_proto(ins) @@ -77,6 +86,7 @@ def fit( self, ins: common.FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.FitRes: """Refine the provided parameters 
using the locally held dataset.""" fit_ins_msg = serde.fit_ins_to_proto(ins) @@ -95,6 +105,7 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" evaluate_msg = serde.evaluate_ins_to_proto(ins) @@ -112,6 +123,7 @@ def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" reconnect_ins_msg = serde.reconnect_ins_to_proto(ins) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py similarity index 90% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py index 329f29b3f616..e7077dfd39ae 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py @@ -22,9 +22,13 @@ import flwr from flwr.common.typing import Config, GetParametersIns -from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar -from flwr.server.fleet.grpc_bidi.grpc_bridge import ResWrapper -from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Parameters, + Scalar, +) +from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ResWrapper +from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage( @@ -67,7 +71,7 @@ def test_get_parameters(self) -> None: # Execute value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None + ins=get_parameters_ins, timeout=None, group_id=0 ) # Assert @@ -84,7 +88,7 
@@ def test_fit(self) -> None: ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) # Execute - fit_res = client.fit(ins=ins, timeout=None) + fit_res = client.fit(ins=ins, timeout=None, group_id=0) # Assert assert fit_res.parameters.tensor_type == "np" @@ -102,7 +106,7 @@ def test_evaluate(self) -> None: evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) # Assert assert (0, 0.0) == ( @@ -123,7 +127,9 @@ def test_get_properties(self) -> None: ) # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) # Assert assert value.properties["tensor_type"] == "numpy.ndarray" diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py similarity index 96% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_server.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index fc81e8eb8f4c..82f049844bd6 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -24,11 +24,15 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.logger import log -from flwr.proto.transport_pb2_grpc import add_FlowerServiceServicer_to_server +from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611 + add_FlowerServiceServicer_to_server, +) from flwr.server.client_manager import ClientManager -from flwr.server.driver.driver_servicer import DriverServicer -from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer -from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer +from flwr.server.superlink.driver.driver_servicer import DriverServicer +from 
flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import ( + FlowerServiceServicer, +) +from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer INVALID_CERTIFICATES_ERR_MSG = """ When setting any of root_certificate, certificate, or private_key, diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py similarity index 96% rename from src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py rename to src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py index 4cd093d6ab0f..8afa37515950 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_server_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py @@ -23,12 +23,12 @@ from typing import Tuple, cast from flwr.server.client_manager import SimpleClientManager -from flwr.server.fleet.grpc_bidi.grpc_server import ( +from flwr.server.superlink.fleet.grpc_bidi.grpc_server import ( start_grpc_server, valid_certificates, ) -root_dir = dirname(abspath(join(__file__, "../../../../../.."))) +root_dir = dirname(abspath(join(__file__, "../../../../../../.."))) def load_certificates() -> Tuple[str, str, str]: diff --git a/src/py/flwr/server/fleet/grpc_rere/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/grpc_rere/__init__.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py diff --git a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py similarity index 80% rename from src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 022470cffe8a..278474477379 100644 --- a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -20,8 +20,8 @@ import grpc from flwr.common.logger import log 
-from flwr.proto import fleet_pb2_grpc -from flwr.proto.fleet_pb2 import ( +from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611 +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -31,15 +31,15 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.server.fleet.message_handler import message_handler -from flwr.server.state import State +from flwr.server.superlink.fleet.message_handler import message_handler +from flwr.server.superlink.state import StateFactory class FleetServicer(fleet_pb2_grpc.FleetServicer): """Fleet API servicer.""" - def __init__(self, state: State) -> None: - self.state = state + def __init__(self, state_factory: StateFactory) -> None: + self.state_factory = state_factory def CreateNode( self, request: CreateNodeRequest, context: grpc.ServicerContext @@ -48,7 +48,7 @@ def CreateNode( log(INFO, "FleetServicer.CreateNode") return message_handler.create_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def DeleteNode( @@ -58,7 +58,7 @@ def DeleteNode( log(INFO, "FleetServicer.DeleteNode") return message_handler.delete_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def PullTaskIns( @@ -68,7 +68,7 @@ def PullTaskIns( log(INFO, "FleetServicer.PullTaskIns") return message_handler.pull_task_ins( request=request, - state=self.state, + state=self.state_factory.state(), ) def PushTaskRes( @@ -78,5 +78,5 @@ def PushTaskRes( log(INFO, "FleetServicer.PushTaskRes") return message_handler.push_task_res( request=request, - state=self.state, + state=self.state_factory.state(), ) diff --git a/src/py/flwr/server/fleet/message_handler/__init__.py b/src/py/flwr/server/superlink/fleet/message_handler/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/message_handler/__init__.py rename to src/py/flwr/server/superlink/fleet/message_handler/__init__.py diff --git 
a/src/py/flwr/server/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py similarity index 89% rename from src/py/flwr/server/fleet/message_handler/message_handler.py rename to src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 71876386f059..c99a7854d53a 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -18,7 +18,7 @@ from typing import List, Optional from uuid import UUID -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -29,9 +29,9 @@ PushTaskResResponse, Reconnect, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes -from flwr.server.state import State +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import State def create_node( @@ -47,7 +47,7 @@ def create_node( def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: """.""" # Validate node_id - if request.node.anonymous or request.node.node_id <= 0: + if request.node.anonymous or request.node.node_id == 0: return DeleteNodeResponse() # Update state diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py similarity index 94% rename from src/py/flwr/server/fleet/message_handler/message_handler_test.py rename to src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py index 25fd822492f2..c135f6fb7b61 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py @@ -17,14 +17,14 @@ from unittest.mock import MagicMock -from 
flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from .message_handler import create_node, delete_node, pull_task_ins, push_task_res @@ -109,7 +109,7 @@ def test_push_task_res() -> None: TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task(), ), ], diff --git a/src/py/flwr/server/fleet/rest_rere/__init__.py b/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py similarity index 100% rename from src/py/flwr/server/fleet/rest_rere/__init__.py rename to src/py/flwr/server/superlink/fleet/rest_rere/__init__.py diff --git a/src/py/flwr/server/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py similarity index 96% rename from src/py/flwr/server/fleet/rest_rere/rest_api.py rename to src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index cd1e47f24f00..b022b34c68c8 100644 --- a/src/py/flwr/server/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -18,14 +18,14 @@ import sys from flwr.common.constant import MISSING_EXTRA_REST -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from flwr.server.fleet.message_handler import message_handler -from flwr.server.state import State +from flwr.server.superlink.fleet.message_handler import message_handler +from flwr.server.superlink.state import State try: from starlette.applications import Starlette diff --git a/src/py/flwr/server/superlink/fleet/vce/__init__.py b/src/py/flwr/server/superlink/fleet/vce/__init__.py new file mode 100644 index 000000000000..57d39688b527 
--- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fleet Simulation Engine side.""" + +from .vce_api import start_vce + +__all__ = [ + "start_vce", +] diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py new file mode 100644 index 000000000000..d751cf4bcae1 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Simulation Engine Backends."""

# NOTE: `importlib.util` must be imported explicitly. Relying on a bare
# `import importlib` to expose the `util` submodule is an implementation
# detail of older CPython versions and fails on newer ones.
import importlib.util
from typing import Dict, Type

from .backend import Backend, BackendConfig

# `ray` is an optional dependency (installed via the `simulation` extra),
# so only register the Ray backend when the module is actually importable.
is_ray_installed = importlib.util.find_spec("ray") is not None

# Mapping of supported backends
supported_backends: Dict[str, Type[Backend]] = {}

# To log backend-specific error message when chosen backend isn't available
error_messages_backends: Dict[str, str] = {}

if is_ray_installed:
    from .raybackend import RayBackend

    supported_backends["ray"] = RayBackend
else:
    error_messages_backends[
        "ray"
    ] = """Unable to import module `ray`.

    To install the necessary dependencies, install `flwr` with the `simulation` extra:

    pip install -U flwr["simulation"]
    """


# `supported_backends` and `error_messages_backends` are part of the public
# surface of this package: `vce_api.py` imports both.
__all__ = [
    "Backend",
    "BackendConfig",
    "error_messages_backends",
    "supported_backends",
]
"""Generic Backend class for Fleet API using the Simulation Engine."""


from abc import ABC, abstractmethod
from typing import Callable, Dict, Tuple

from flwr.client.client_app import ClientApp
from flwr.common.context import Context
from flwr.common.message import Message
from flwr.common.typing import ConfigsRecordValues

# A backend config is a two-level mapping, e.g.
# {"client_resources": {"num_cpus": 2, "num_gpus": 0.0}}.
BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]]


class Backend(ABC):
    """Abstract base class for a Simulation Engine Backend."""

    def __init__(self, backend_config: BackendConfig, work_dir: str) -> None:
        """Construct a backend."""
        # Intentionally a no-op: concrete backends (e.g. RayBackend) consume
        # `backend_config` and `work_dir` in their own constructors.

    @abstractmethod
    async def build(self) -> None:
        """Build backend asynchronously.

        Different components need to be in place before workers in a backend are ready
        to accept jobs. When this method finishes executing, the backend should be fully
        ready to run jobs.
        """

    @property
    def num_workers(self) -> int:
        """Return number of workers in the backend.

        This is the number of TaskIns that can be processed concurrently.
        """
        # Default is 0 until a concrete backend overrides this property
        # (workers are typically only available after `build()` completes).
        return 0

    @abstractmethod
    def is_worker_idle(self) -> bool:
        """Report whether a backend worker is idle and can therefore run a ClientApp."""

    @abstractmethod
    async def terminate(self) -> None:
        """Terminate backend."""

    @abstractmethod
    async def process_message(
        self,
        app: Callable[[], ClientApp],
        message: Message,
        context: Context,
    ) -> Tuple[Message, Context]:
        """Submit a job to the backend."""
"""Ray backend for the Fleet API using the Simulation Engine."""

import pathlib
from logging import ERROR, INFO
from typing import Callable, Dict, List, Tuple, Union

import ray

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common.context import Context
from flwr.common.logger import log
from flwr.common.message import Message
from flwr.simulation.ray_transport.ray_actor import (
    BasicActorPool,
    ClientAppActor,
    init_ray,
)
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth

from .backend import Backend, BackendConfig

# Resource request per actor, e.g. {"num_cpus": 2, "num_gpus": 0.0}
ClientResourcesDict = Dict[str, Union[int, float]]


class RayBackend(Backend):
    """A backend that submits jobs to a `BasicActorPool`."""

    def __init__(
        self,
        backend_config: BackendConfig,
        work_dir: str,
    ) -> None:
        """Prepare RayBackend by initialising Ray and creating the ActorPool."""
        log(INFO, "Initialising: %s", self.__class__.__name__)
        log(INFO, "Backend config: %s", backend_config)

        # Note: an empty `work_dir` passes this check (Path("") == Path(".")),
        # and is then treated as "no runtime env" below.
        if not pathlib.Path(work_dir).exists():
            raise ValueError(f"Specified work_dir {work_dir} does not exist.")

        # Init ray and append working dir if needed
        runtime_env = (
            self._configure_runtime_env(work_dir=work_dir) if work_dir else None
        )
        init_ray(runtime_env=runtime_env)

        # Key under which client resources are expected in the backend config
        self.client_resources_key = "client_resources"

        # Create actor pool
        use_tf = backend_config.get("tensorflow", False)
        actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}

        client_resources = self._validate_client_resources(config=backend_config)
        self.pool = BasicActorPool(
            actor_type=ClientAppActor,
            client_resources=client_resources,
            actor_kwargs=actor_kwargs,
        )

    def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
        """Return a Ray runtime env for `work_dir` excluding non-Python files.

        Without this, Ray will push everything to the Ray Cluster.
        """
        runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}

        excludes = []
        path = pathlib.Path(work_dir)
        for p in path.rglob("*"):
            # Exclude files need to be relative to the working_dir
            if p.is_file() and not str(p).endswith(".py"):
                excludes.append(str(p.relative_to(path)))
        runtime_env["excludes"] = excludes

        return runtime_env

    def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
        # Validate the `client_resources` entry of the backend config; fall
        # back to a default (2 CPUs, no GPU) when it is absent.
        client_resources_config = config.get(self.client_resources_key)
        client_resources: ClientResourcesDict = {}
        valid_types = (int, float)
        if client_resources_config:
            for k, v in client_resources_config.items():
                if not isinstance(k, str):
                    raise ValueError(
                        f"client resources keys are expected to be `str` but you used "
                        f"{type(k)} for `{k}`"
                    )
                if not isinstance(v, valid_types):
                    raise ValueError(
                        f"client resources are expected to be of type {valid_types} "
                        f"but found `{type(v)}` for key `{k}`",
                    )
                client_resources[k] = v

        else:
            client_resources = {"num_cpus": 2, "num_gpus": 0.0}
            log(
                INFO,
                "`%s` not specified in backend config. Applying default setting: %s",
                self.client_resources_key,
                client_resources,
            )

        return client_resources

    @property
    def num_workers(self) -> int:
        """Return number of actors in pool."""
        return self.pool.num_actors

    def is_worker_idle(self) -> bool:
        """Report whether the pool has idle actors."""
        return self.pool.is_actor_available()

    async def build(self) -> None:
        """Build pool of Ray actors that this backend will submit jobs to."""
        await self.pool.add_actors_to_pool(self.pool.actors_capacity)
        log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)

    async def process_message(
        self,
        app: Callable[[], ClientApp],
        message: Message,
        context: Context,
    ) -> Tuple[Message, Context]:
        """Run the ClientApp that processes a given message.

        Return output message and updated context.
        """
        partition_id = message.metadata.partition_id

        try:
            # Submit a task to the pool
            future = await self.pool.submit(
                lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
                (app, message, str(partition_id), context),
            )

            await future

            # Fetch result
            (
                out_mssg,
                updated_context,
            ) = await self.pool.fetch_result_and_return_actor_to_pool(future)

            return out_mssg, updated_context

        except LoadClientAppError as load_ex:
            log(
                ERROR,
                "An exception was raised when processing a message by %s",
                self.__class__.__name__,
            )
            raise load_ex

    async def terminate(self) -> None:
        """Terminate all actors in actor pool."""
        await self.pool.terminate_all_actors()
        ray.shutdown()
        log(INFO, "Terminated %s", self.__class__.__name__)
"""Test for Ray backend for the Fleet API using the Simulation Engine."""

import asyncio
from math import pi
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Union
from unittest import IsolatedAsyncioTestCase

import ray

from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import (
    Config,
    ConfigsRecord,
    Context,
    GetPropertiesIns,
    Message,
    MessageTypeLegacy,
    Metadata,
    RecordSet,
    Scalar,
)
from flwr.common.object_ref import load_app
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.server.superlink.fleet.vce.backend.raybackend import RayBackend


class DummyClient(NumPyClient):
    """A dummy NumPyClient for tests."""

    def get_properties(self, config: Config) -> Dict[str, Scalar]:
        """Return properties by doing a simple calculation."""
        result = float(config["factor"]) * pi

        # store something in context
        self.context.state.configs_records["result"] = ConfigsRecord({"result": result})
        return {"result": result}


def get_dummy_client(cid: str) -> Client:  # pylint: disable=unused-argument
    """Return a DummyClient converted to Client type."""
    return DummyClient().to_client()


def _load_app() -> ClientApp:
    """Return a ClientApp backed by the dummy client."""
    return ClientApp(client_fn=get_dummy_client)


# Module-level ClientApp so tests can load it via "raybackend_test:client_app"
client_app = ClientApp(
    client_fn=get_dummy_client,
)


def _load_from_module(client_app_module_name: str) -> Callable[[], ClientApp]:
    """Return a loader that imports a ClientApp from `module:attribute`."""

    def _load_app() -> ClientApp:
        app = load_app(client_app_module_name, LoadClientAppError)

        if not isinstance(app, ClientApp):
            raise LoadClientAppError(
                f"Attribute {client_app_module_name} is not of type {ClientApp}",
            ) from None

        return app

    return _load_app


async def backend_build_process_and_termination(
    backend: RayBackend,
    process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None,
) -> Union[Tuple[Message, Context], None]:
    """Build, process job and terminate RayBackend."""
    await backend.build()
    to_return = None

    if process_args:
        to_return = await backend.process_message(*process_args)

    await backend.terminate()

    return to_return


def _create_message_and_context() -> Tuple[Message, Context, float]:
    """Return a GetProperties message, an empty context and the expected result."""
    # Construct a Message
    mult_factor = 2024
    getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
    recordset = getpropertiesins_to_recordset(getproperties_ins)
    message = Message(
        content=recordset,
        metadata=Metadata(
            run_id=0,
            message_id="",
            group_id="",
            src_node_id=0,
            dst_node_id=0,
            reply_to_message="",
            ttl="",
            message_type=MessageTypeLegacy.GET_PROPERTIES,
        ),
    )

    # Construct empty Context
    context = Context(state=RecordSet())

    # Expected output (mirrors DummyClient.get_properties)
    expected_output = pi * mult_factor

    return message, context, expected_output


class AsyncTestRayBackend(IsolatedAsyncioTestCase):
    """A basic class that allows running multiple asyncio tests."""

    async def on_cleanup(self) -> None:
        """Ensure Ray has shutdown."""
        if ray.is_initialized():
            ray.shutdown()

    def test_backend_creation_and_termination(self) -> None:
        """Test creation of RayBackend and its termination."""
        backend = RayBackend(backend_config={}, work_dir="")
        asyncio.run(
            backend_build_process_and_termination(backend=backend, process_args=None)
        )

    def test_backend_creation_submit_and_termination(
        self,
        client_app_loader: Callable[[], ClientApp] = _load_app,
        workdir: str = "",
    ) -> None:
        """Test submitting a message to a given ClientApp."""
        backend = RayBackend(backend_config={}, work_dir=workdir)

        # Define ClientApp
        client_app_callable = client_app_loader

        message, context, expected_output = _create_message_and_context()

        res = asyncio.run(
            backend_build_process_and_termination(
                backend=backend, process_args=(client_app_callable, message, context)
            )
        )

        if res is None:
            raise AssertionError("This shouldn't happen")

        out_mssg, updated_context = res

        # Verify message content is as expected
        content = out_mssg.content
        assert (
            content.configs_records["getpropertiesres.properties"]["result"]
            == expected_output
        )

        # Verify context is correct
        obtained_result_in_context = updated_context.state.configs_records["result"][
            "result"
        ]
        assert obtained_result_in_context == expected_output

    def test_backend_creation_submit_and_termination_non_existing_client_app(
        self,
    ) -> None:
        """Testing with ClientApp module that does not exist."""
        with self.assertRaises(LoadClientAppError):
            self.test_backend_creation_submit_and_termination(
                client_app_loader=_load_from_module("a_non_existing_module:app")
            )
        self.addAsyncCleanup(self.on_cleanup)

    def test_backend_creation_submit_and_termination_existing_client_app(
        self,
    ) -> None:
        """Testing with ClientApp module that exists."""
        # Resolve what should be the workdir to pass upon Backend initialisation
        file_path = Path(__file__)
        working_dir = Path.cwd()
        rel_workdir = file_path.relative_to(working_dir)

        # Subtract last element (the file name), keeping the directory
        rel_workdir_str = str(rel_workdir.parent)

        self.test_backend_creation_submit_and_termination(
            client_app_loader=_load_from_module("raybackend_test:client_app"),
            workdir=rel_workdir_str,
        )

    def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdir(
        self,
    ) -> None:
        """Testing with ClientApp module that exists but the passed workdir does not."""
        with self.assertRaises(ValueError):
            self.test_backend_creation_submit_and_termination(
                client_app_loader=_load_from_module("raybackend_test:client_app"),
                workdir="/?&%$^#%@$!",
            )
        self.addAsyncCleanup(self.on_cleanup)
"""Fleet Simulation Engine API."""


import asyncio
import json
import traceback
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, Dict, List, Optional

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common.logger import log
from flwr.common.object_ref import load_app
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
from flwr.server.superlink.state import StateFactory

from .backend import Backend, error_messages_backends, supported_backends

# Maps a registered node-id to the dataset partition it owns
NodeToPartitionMapping = Dict[int, int]


def _register_nodes(
    num_nodes: int, state_factory: StateFactory
) -> NodeToPartitionMapping:
    """Register nodes with the StateFactory and create node-id:partition-id mapping."""
    nodes_mapping: NodeToPartitionMapping = {}
    state = state_factory.state()
    for i in range(num_nodes):
        node_id = state.create_node()
        nodes_mapping[node_id] = i
    log(INFO, "Registered %i nodes", len(nodes_mapping))
    return nodes_mapping


# pylint: disable=too-many-arguments,too-many-locals
async def worker(
    app_fn: Callable[[], ClientApp],
    queue: "asyncio.Queue[TaskIns]",
    node_states: Dict[int, NodeState],
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
) -> None:
    """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
    state = state_factory.state()
    while True:
        try:
            task_ins: TaskIns = await queue.get()
            node_id = task_ins.task.consumer.node_id

            # Register and retrieve runstate
            node_states[node_id].register_context(run_id=task_ins.run_id)
            context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)

            # Convert TaskIns to Message
            message = message_from_taskins(task_ins)
            # Set partition_id so the ClientApp knows which data partition to use
            message.metadata.partition_id = nodes_mapping[node_id]

            # Let backend process message
            out_mssg, updated_context = await backend.process_message(
                app_fn, message, context
            )

            # Update Context
            node_states[node_id].update_context(
                task_ins.run_id, context=updated_context
            )

            # Convert to TaskRes and store it so the Driver can collect it
            task_res = message_to_taskres(out_mssg)
            state.store_task_res(task_res)

        except asyncio.CancelledError as e:
            # Normal shutdown path: `run()` cancels workers on teardown
            log(DEBUG, "Async worker: %s", e)
            break

        except LoadClientAppError as app_ex:
            # A broken ClientApp is fatal: re-raise so the run crashes loudly
            log(ERROR, "Async worker: %s", app_ex)
            log(ERROR, traceback.format_exc())
            raise

        except Exception as ex:  # pylint: disable=broad-exception-caught
            # Any other error terminates this worker (but not its siblings)
            log(ERROR, ex)
            log(ERROR, traceback.format_exc())
            break


async def add_taskins_to_queue(
    queue: "asyncio.Queue[TaskIns]",
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
    consumers: List["asyncio.Task[None]"],
    f_stop: asyncio.Event,
) -> None:
    """Retrieve TaskIns from the state and add them to the queue."""
    state = state_factory.state()
    num_initial_consumers = len(consumers)
    while not f_stop.is_set():
        for node_id in nodes_mapping:
            task_ins = state.get_task_ins(node_id=node_id, limit=1)
            if task_ins:
                await queue.put(task_ins[0])

        # Count consumers that are running
        num_active = sum(not (cc.done()) for cc in consumers)

        # Alert if number of consumers decreased by half
        if num_active < num_initial_consumers // 2:
            log(
                WARN,
                "Number of active workers has more than halved: (%i/%i active)",
                num_active,
                num_initial_consumers,
            )

        # Break if consumers died
        if num_active == 0:
            raise RuntimeError("All workers have died. Ending Simulation.")

        # Log some stats
        log(
            DEBUG,
            "Simulation Engine stats: "
            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
            num_active,
            num_initial_consumers,
            backend.__class__.__name__,
            backend.num_workers,
            queue.qsize(),
        )
        await asyncio.sleep(1.0)
    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")


async def run(
    app_fn: Callable[[], ClientApp],
    backend_fn: Callable[[], Backend],
    nodes_mapping: NodeToPartitionMapping,
    state_factory: StateFactory,
    node_states: Dict[int, NodeState],
    f_stop: asyncio.Event,
) -> None:
    """Run the VCE async."""
    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)

    # Initialise these before the `try` so the `finally` clause below is safe
    # even when instantiating or building the backend raises: previously an
    # early failure caused a NameError on `worker_tasks`/`backend` in `finally`
    # that masked the original exception.
    worker_tasks: List["asyncio.Task[None]"] = []
    backend: Optional[Backend] = None

    try:

        # Instantiate backend
        backend = backend_fn()

        # Build backend
        await backend.build()

        # Add workers (they submit Messages to Backend)
        worker_tasks = [
            asyncio.create_task(
                worker(
                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
                )
            )
            for _ in range(backend.num_workers)
        ]
        # Create producer (adds TaskIns into Queue)
        producer = asyncio.create_task(
            add_taskins_to_queue(
                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
            )
        )

        # Wait for producer to finish
        # The producer runs forever until f_stop is set or until
        # all worker (consumer) coroutines are completed. Workers
        # also run forever and only end if an exception is raised.
        await asyncio.gather(producer)

    except Exception as ex:

        log(ERROR, "An exception occured!! %s", ex)
        log(ERROR, traceback.format_exc())
        log(WARN, "Stopping Simulation Engine.")

        # Manually trigger stopping event
        f_stop.set()

        # Raise exception
        raise RuntimeError("Simulation Engine crashed.") from ex

    finally:
        # Producer terminated, now cancel worker tasks
        for w_t in worker_tasks:
            _ = w_t.cancel()

        while not all(w_t.done() for w_t in worker_tasks):
            log(DEBUG, "Terminating async workers...")
            await asyncio.sleep(0.5)

        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])

        # Terminate backend (only if it was successfully instantiated)
        if backend is not None:
            await backend.terminate()


# pylint: disable=too-many-arguments,unused-argument,too-many-locals
def start_vce(
    backend_name: str,
    backend_config_json_stream: str,
    app_dir: str,
    f_stop: asyncio.Event,
    client_app: Optional[ClientApp] = None,
    client_app_attr: Optional[str] = None,
    num_supernodes: Optional[int] = None,
    state_factory: Optional[StateFactory] = None,
    existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
) -> None:
    """Start Fleet API with the Simulation Engine.

    Exactly one of `client_app` / `client_app_attr` must be provided, and either
    `num_supernodes` or an (`existing_nodes_mapping`, `state_factory`) pair.

    Raises
    ------
    ValueError
        If the argument combinations above are violated.
    """
    if client_app_attr is not None and client_app is not None:
        raise ValueError(
            "Both `client_app_attr` and `client_app` are provided, "
            "but only one is allowed."
        )
    # Guard added: without it, `_load()` below would hit an unbound local
    # when neither a ClientApp object nor an attribute string was supplied.
    if client_app_attr is None and client_app is None:
        raise ValueError(
            "Either `client_app_attr` or `client_app` must be provided."
        )

    if num_supernodes is not None and existing_nodes_mapping is not None:
        raise ValueError(
            "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
            "but only one is allowed."
        )
    if num_supernodes is None:
        if state_factory is None or existing_nodes_mapping is None:
            raise ValueError(
                "If not passing an existing `state_factory` and associated "
                "`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
                "how many nodes to insert into a new StateFactory that will be created."
            )
    if existing_nodes_mapping:
        if state_factory is None:
            raise ValueError(
                "`existing_nodes_mapping` was passed, but no `state_factory` was "
                "passed."
            )
        log(INFO, "Using exiting NodeToPartitionMapping and StateFactory.")
        # Use mapping constructed externally. This also means nodes
        # have previously been registered.
        nodes_mapping = existing_nodes_mapping

    if not state_factory:
        log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
        # Create an empty in-memory state factory
        state_factory = StateFactory(":flwr-in-memory-state:")
        log(INFO, "Created new %s.", state_factory.__class__.__name__)

    if num_supernodes:
        # Register SuperNodes
        nodes_mapping = _register_nodes(
            num_nodes=num_supernodes, state_factory=state_factory
        )

    # Construct mapping of NodeStates
    node_states: Dict[int, NodeState] = {}
    for node_id in nodes_mapping:
        node_states[node_id] = NodeState()

    # Load backend config (raises json.JSONDecodeError on malformed input)
    log(INFO, "Supported backends: %s", list(supported_backends.keys()))
    backend_config = json.loads(backend_config_json_stream)

    try:
        backend_type = supported_backends[backend_name]
    except KeyError as ex:
        log(
            ERROR,
            "Backend `%s`, is not supported. Use any of %s or add support "
            "for a new backend.",
            backend_name,
            list(supported_backends.keys()),
        )
        if backend_name in error_messages_backends:
            log(ERROR, error_messages_backends[backend_name])

        raise ex

    def backend_fn() -> Backend:
        """Instantiate a Backend."""
        return backend_type(backend_config, work_dir=app_dir)

    log(INFO, "client_app_attr = %s", client_app_attr)

    # Load ClientApp if needed
    def _load() -> ClientApp:

        if client_app_attr:
            app: ClientApp = load_app(client_app_attr, LoadClientAppError)

            if not isinstance(app, ClientApp):
                raise LoadClientAppError(
                    f"Attribute {client_app_attr} is not of type {ClientApp}",
                ) from None

        if client_app:
            app = client_app
        return app

    app_fn = _load

    asyncio.run(
        run(
            app_fn,
            backend_fn,
            nodes_mapping,
            state_factory,
            node_states,
            f_stop,
        )
    )
"""Test Fleet Simulation Engine API."""


import asyncio
import threading
from itertools import cycle
from json import JSONDecodeError
from math import pi
from pathlib import Path
from time import sleep
from typing import Dict, Optional, Set, Tuple
from unittest import IsolatedAsyncioTestCase
from uuid import UUID

from flwr.common import GetPropertiesIns, Message, MessageTypeLegacy, Metadata
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.server.superlink.fleet.vce.vce_api import (
    NodeToPartitionMapping,
    _register_nodes,
    start_vce,
)
from flwr.server.superlink.state import InMemoryState, StateFactory


def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None:
    """Set event to terminate Simulation Engine after `sleep_duration` seconds."""
    sleep(sleep_duration)
    f_stop.set()


def init_state_factory_nodes_mapping(
    num_nodes: int,
    num_messages: int,
    erroneous_message: Optional[bool] = False,
) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]:
    """Instantiate StateFactory, register nodes and pre-insert messages in the state."""
    # Register a state and a run_id in it
    run_id = 1234
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Register a few nodes
    nodes_mapping = _register_nodes(num_nodes=num_nodes, state_factory=state_factory)

    expected_results = register_messages_into_state(
        state_factory=state_factory,
        nodes_mapping=nodes_mapping,
        run_id=run_id,
        num_messages=num_messages,
        erroneous_message=erroneous_message,
    )
    return state_factory, nodes_mapping, expected_results


# pylint: disable=too-many-locals
def register_messages_into_state(
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    run_id: int,
    num_messages: int,
    erroneous_message: Optional[bool] = False,
) -> Dict[UUID, float]:
    """Register `num_messages` into the state factory.

    Returns a mapping of task-id to the result each message is expected to
    produce (mirrors DummyClient.get_properties in raybackend_test).
    """
    state: InMemoryState = state_factory.state()  # type: ignore
    state.run_ids.add(run_id)
    # Artificially add TaskIns to state so they can be processed
    # by the Simulation Engine logic
    nodes_cycle = cycle(nodes_mapping.keys())  # we have more messages than supernodes
    task_ids: Set[UUID] = set()  # so we can retrieve them later
    expected_results = {}
    for i in range(num_messages):
        dst_node_id = next(nodes_cycle)
        # Construct a Message
        mult_factor = 2024 + i
        getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
        recordset = getpropertiesins_to_recordset(getproperties_ins)
        message = Message(
            content=recordset,
            metadata=Metadata(
                run_id=run_id,
                message_id="",
                group_id="",
                src_node_id=0,
                dst_node_id=dst_node_id,  # indicate destination node
                reply_to_message="",
                ttl="",
                message_type=(
                    "a bad message"
                    if erroneous_message
                    else MessageTypeLegacy.GET_PROPERTIES
                ),
            ),
        )
        # Convert Message to TaskIns
        taskins = message_to_taskins(message)
        # Insert into state
        task_id = state.store_task_ins(taskins)
        if task_id:
            # Add to UUID set
            task_ids.add(task_id)
            # Store expected output for check later on
            expected_results[task_id] = mult_factor * pi

    return expected_results


def _autoresolve_app_dir(rel_client_app_dir: str = "backend") -> str:
    """Correctly resolve working directory for the app."""
    file_path = Path(__file__)
    app_dir = Path.cwd()
    rel_app_dir = file_path.relative_to(app_dir)

    # Drop the last element (this file's name) and append `rel_client_app_dir`
    # (where the client module lives)
    return str(rel_app_dir.parent / rel_client_app_dir)


# pylint: disable=too-many-arguments
def start_and_shutdown(
    backend: str = "ray",
    client_app_attr: str = "raybackend_test:client_app",
    app_dir: str = "",
    num_supernodes: Optional[int] = None,
    state_factory: Optional[StateFactory] = None,
    nodes_mapping: Optional[NodeToPartitionMapping] = None,
    duration: int = 0,
    backend_config: str = "{}",
) -> None:
    """Start Simulation Engine and terminate after specified number of seconds.

    Some tests need to be terminated by triggering externally an asyncio.Event. This
    is enabled when passing `duration`>0.
    """
    f_stop = asyncio.Event()

    if duration:

        # Setup thread that will set the f_stop event, triggering the termination of all
        # asyncio logic in the Simulation Engine. It will also terminate the Backend.
        termination_th = threading.Thread(
            target=terminate_simulation, args=(f_stop, duration)
        )
        termination_th.start()

    # Resolve working directory if not passed
    if not app_dir:
        app_dir = _autoresolve_app_dir()

    start_vce(
        num_supernodes=num_supernodes,
        client_app_attr=client_app_attr,
        backend_name=backend,
        backend_config_json_stream=backend_config,
        state_factory=state_factory,
        app_dir=app_dir,
        f_stop=f_stop,
        existing_nodes_mapping=nodes_mapping,
    )

    if duration:
        termination_th.join()


class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase):
    """A basic class that enables testing asyncio functionalities."""

    def test_erroneous_no_supernodes_client_mapping(self) -> None:
        """Test with unset arguments."""
        with self.assertRaises(ValueError):
            start_and_shutdown(duration=2)

    def test_erroneous_client_app_attr(self) -> None:
        """Tests attempt to load a ClientApp that can't be found."""
        num_messages = 7
        num_nodes = 59

        state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping(
            num_nodes=num_nodes, num_messages=num_messages
        )
        with self.assertRaises(RuntimeError):
            start_and_shutdown(
                client_app_attr="totally_fictitious_app:client",
                state_factory=state_factory,
                nodes_mapping=nodes_mapping,
            )

    def test_erroneous_messages(self) -> None:
        """Test handling of error in async worker (consumer).

        We register messages which will trigger an error when handling, triggering an
        error.
        """
        num_messages = 100
        num_nodes = 59

        state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping(
            num_nodes=num_nodes, num_messages=num_messages, erroneous_message=True
        )

        with self.assertRaises(RuntimeError):
            start_and_shutdown(
                state_factory=state_factory,
                nodes_mapping=nodes_mapping,
            )

    def test_erroneous_backend_config(self) -> None:
        """Backend Config should be a JSON stream."""
        with self.assertRaises(JSONDecodeError):
            start_and_shutdown(num_supernodes=50, backend_config="not a proper config")

    def test_with_nonexistent_backend(self) -> None:
        """Test specifying a backend that does not exist."""
        with self.assertRaises(KeyError):
            start_and_shutdown(num_supernodes=50, backend="this-backend-does-not-exist")

    def test_erroneous_arguments_num_supernodes_and_existing_mapping(self) -> None:
        """Test ValueError if a node mapping is passed but also num_supernodes.

        Passing `num_supernodes` does nothing since we assume that if a node mapping
        is supplied, nodes have been registered externally already. Therefore passing
        `num_supernodes` might give the impression that that many nodes will be
        registered. We don't do that since a mapping already exists.
        """
        with self.assertRaises(ValueError):
            start_and_shutdown(num_supernodes=50, nodes_mapping={0: 1})

    def test_erroneous_arguments_existing_mapping_but_no_state_factory(self) -> None:
        """Test ValueError if a node mapping is passed but no state.

        Passing a node mapping indicates that (externally) nodes have registered with a
        state factory. Therefore, that state factory should be passed too.
        """
        with self.assertRaises(ValueError):
            start_and_shutdown(nodes_mapping={0: 1})

    def test_start_and_shutdown(self) -> None:
        """Start Simulation Engine Fleet and terminate it."""
        start_and_shutdown(num_supernodes=50, duration=10)

    # pylint: disable=too-many-locals
    def test_start_and_shutdown_with_tasks_in_state(self) -> None:
        """Run Simulation Engine with some TasksIns in State.

        This test creates a few nodes and submits a few messages that need to be
        executed by the Backend. In order for that to happen the asyncio
        producer/consumer logic must function. This also serves to evaluate a valid
        ClientApp.
        """
        num_messages = 229
        num_nodes = 59

        state_factory, nodes_mapping, expected_results = (
            init_state_factory_nodes_mapping(
                num_nodes=num_nodes, num_messages=num_messages
            )
        )

        # Run
        start_and_shutdown(
            state_factory=state_factory, nodes_mapping=nodes_mapping, duration=10
        )

        # Get all TaskRes
        state = state_factory.state()
        task_ids = set(expected_results.keys())
        task_res_list = state.get_task_res(task_ids=task_ids, limit=len(task_ids))

        # Check results by first converting to Message
        for task_res in task_res_list:

            message = message_from_taskres(task_res)

            # Verify message content is as expected
            content = message.content
            assert (
                content.configs_records["getpropertiesres.properties"]["result"]
                == expected_results[UUID(task_res.task.ancestry[0])]
            )
a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -16,14 +16,15 @@ import os +import threading from datetime import datetime, timedelta from logging import ERROR from typing import Dict, List, Optional, Set from uuid import UUID, uuid4 from flwr.common import log, now -from flwr.proto.task_pb2 import TaskIns, TaskRes -from flwr.server.state.state import State +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res @@ -32,9 +33,10 @@ class InMemoryState(State): def __init__(self) -> None: self.node_ids: Set[int] = set() - self.workload_ids: Set[int] = set() + self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} + self.lock = threading.Lock() def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: """Store one TaskIns.""" @@ -43,9 +45,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: if any(errors): log(ERROR, errors) return None - # Validate workload_id - if task_ins.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_ins.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -57,7 +59,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() task_ins.task.ttl = ttl.isoformat() - self.task_ins_store[task_id] = task_ins + with self.lock: + self.task_ins_store[task_id] = task_ins # Return the new task_id return task_id @@ -71,22 +74,23 @@ def get_task_ins( # Find TaskIns for node_id that were not delivered yet task_ins_list: List[TaskIns] = [] - for _, task_ins in self.task_ins_store.items(): - # pylint: disable=too-many-boolean-expressions - if ( - node_id is not None # Not anonymous - 
and task_ins.task.consumer.anonymous is False - and task_ins.task.consumer.node_id == node_id - and task_ins.task.delivered_at == "" - ) or ( - node_id is None # Anonymous - and task_ins.task.consumer.anonymous is True - and task_ins.task.consumer.node_id == 0 - and task_ins.task.delivered_at == "" - ): - task_ins_list.append(task_ins) - if limit and len(task_ins_list) == limit: - break + with self.lock: + for _, task_ins in self.task_ins_store.items(): + # pylint: disable=too-many-boolean-expressions + if ( + node_id is not None # Not anonymous + and task_ins.task.consumer.anonymous is False + and task_ins.task.consumer.node_id == node_id + and task_ins.task.delivered_at == "" + ) or ( + node_id is None # Anonymous + and task_ins.task.consumer.anonymous is True + and task_ins.task.consumer.node_id == 0 + and task_ins.task.delivered_at == "" + ): + task_ins_list.append(task_ins) + if limit and len(task_ins_list) == limit: + break # Mark all of them as delivered delivered_at = now().isoformat() @@ -104,9 +108,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Validate workload_id - if task_res.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_res.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -118,7 +122,8 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() task_res.task.ttl = ttl.isoformat() - self.task_res_store[task_id] = task_res + with self.lock: + self.task_res_store[task_id] = task_res # Return the new task_id return task_id @@ -128,45 +133,47 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe if limit is not None and limit < 1: raise AssertionError("`limit` must be >= 1") - # Find TaskRes that were not delivered yet - task_res_list: List[TaskRes] = [] - 
for _, task_res in self.task_res_store.items(): - if ( - UUID(task_res.task.ancestry[0]) in task_ids - and task_res.task.delivered_at == "" - ): - task_res_list.append(task_res) - if limit and len(task_res_list) == limit: - break - - # Mark all of them as delivered - delivered_at = now().isoformat() - for task_res in task_res_list: - task_res.task.delivered_at = delivered_at - - # Return TaskRes - return task_res_list + with self.lock: + # Find TaskRes that were not delivered yet + task_res_list: List[TaskRes] = [] + for _, task_res in self.task_res_store.items(): + if ( + UUID(task_res.task.ancestry[0]) in task_ids + and task_res.task.delivered_at == "" + ): + task_res_list.append(task_res) + if limit and len(task_res_list) == limit: + break + + # Mark all of them as delivered + delivered_at = now().isoformat() + for task_res in task_res_list: + task_res.task.delivered_at = delivered_at + + # Return TaskRes + return task_res_list def delete_tasks(self, task_ids: Set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" task_ins_to_be_deleted: Set[UUID] = set() task_res_to_be_deleted: Set[UUID] = set() - for task_ins_id in task_ids: - # Find the task_id of the matching task_res - for task_res_id, task_res in self.task_res_store.items(): - if UUID(task_res.task.ancestry[0]) != task_ins_id: - continue - if task_res.task.delivered_at == "": - continue + with self.lock: + for task_ins_id in task_ids: + # Find the task_id of the matching task_res + for task_res_id, task_res in self.task_res_store.items(): + if UUID(task_res.task.ancestry[0]) != task_ins_id: + continue + if task_res.task.delivered_at == "": + continue - task_ins_to_be_deleted.add(task_ins_id) - task_res_to_be_deleted.add(task_res_id) + task_ins_to_be_deleted.add(task_ins_id) + task_res_to_be_deleted.add(task_res_id) - for task_id in task_ins_to_be_deleted: - del self.task_ins_store[task_id] - for task_id in task_res_to_be_deleted: - del self.task_res_store[task_id] + for task_id in 
task_ins_to_be_deleted: + del self.task_ins_store[task_id] + for task_id in task_res_to_be_deleted: + del self.task_res_store[task_id] def num_task_ins(self) -> int: """Calculate the number of task_ins in store. @@ -199,25 +206,25 @@ def delete_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} not found") self.node_ids.remove(node_id) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Return all available client nodes. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ - if workload_id not in self.workload_ids: + if run_id not in self.run_ids: return set() return self.node_ids - def create_workload(self) -> int: - """Create one workload.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - if workload_id not in self.workload_ids: - self.workload_ids.add(workload_id) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + if run_id not in self.run_ids: + self.run_ids.add(run_id) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py similarity index 87% rename from src/py/flwr/server/state/sqlite_state.py rename to src/py/flwr/server/superlink/state/sqlite_state.py index f3ff60f370e9..224c16cdf013 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -24,9 +24,9 @@ from uuid import UUID, uuid4 from flwr.common import log, now -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes 
-from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.utils.validator import validate_task_ins_or_res from .state import State @@ -37,9 +37,9 @@ ); """ -SQL_CREATE_TABLE_WORKLOAD = """ -CREATE TABLE IF NOT EXISTS workload( - workload_id INTEGER UNIQUE +SQL_CREATE_TABLE_RUN = """ +CREATE TABLE IF NOT EXISTS run( + run_id INTEGER UNIQUE ); """ @@ -47,7 +47,7 @@ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -56,9 +56,9 @@ delivered_at TEXT, ttl TEXT, ancestry TEXT, - legacy_server_message BLOB, - legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + task_type TEXT, + recordset BLOB, + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS task_res( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -76,9 +76,9 @@ delivered_at TEXT, ttl TEXT, ancestry TEXT, - legacy_server_message BLOB, - legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + task_type TEXT, + recordset BLOB, + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -119,7 +119,7 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur = self.conn.cursor() # Create each table if not exists queries - cur.execute(SQL_CREATE_TABLE_WORKLOAD) + cur.execute(SQL_CREATE_TABLE_RUN) cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) @@ -134,7 +134,7 @@ def query( ) -> List[Dict[str, Any]]: """Execute a SQL query.""" if self.conn 
is None: - raise Exception("State is not initialized.") + raise AttributeError("State is not initialized.") if data is None: data = [] @@ -198,12 +198,12 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. + # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -333,12 +333,12 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. + # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """ if self.conn is None: - raise Exception("State not intitialized") + raise AttributeError("State not intitialized") with self.conn: self.conn.execute(query_1, data) @@ -485,17 +485,17 @@ def delete_node(self, node_id: int) -> None: query = "DELETE FROM node WHERE node_id = :node_id;" self.query(query, {"node_id": node_id}) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. 
""" - # Validate workload ID - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: + # Validate run ID + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: return set() # Get nodes @@ -504,19 +504,19 @@ def get_nodes(self, workload_id: int) -> Set[int]: result: Set[int] = {row["node_id"] for row in rows} return result - def create_workload(self) -> int: - """Create one workload and store it in state.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run and store it in state.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) # Check conflicts - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - # If workload_id does not exist - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO workload VALUES(:workload_id);" - self.query(query, {"workload_id": workload_id}) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + # If run_id does not exist + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: + query = "INSERT INTO run VALUES(:run_id);" + self.query(query, {"run_id": run_id}) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 @@ -537,7 +537,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -546,10 +546,8 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: "delivered_at": task_msg.task.delivered_at, "ttl": 
task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), - "legacy_server_message": ( - task_msg.task.legacy_server_message.SerializeToString() - ), - "legacy_client_message": None, + "task_type": task_msg.task.task_type, + "recordset": task_msg.task.recordset.SerializeToString(), } return result @@ -559,7 +557,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -568,23 +566,21 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: "delivered_at": task_msg.task.delivered_at, "ttl": task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), - "legacy_server_message": None, - "legacy_client_message": ( - task_msg.task.legacy_client_message.SerializeToString() - ), + "task_type": task_msg.task.task_type, + "recordset": task_msg.task.recordset.SerializeToString(), } return result def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: """Turn task_dict into protobuf message.""" - server_message = ServerMessage() - server_message.ParseFromString(task_dict["legacy_server_message"]) + recordset = RecordSet() + recordset.ParseFromString(task_dict["recordset"]) result = TaskIns( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], @@ -598,7 +594,8 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: delivered_at=task_dict["delivered_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), - legacy_server_message=server_message, + task_type=task_dict["task_type"], + recordset=recordset, ), ) return result @@ -606,13 +603,13 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) 
-> TaskIns: def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: """Turn task_dict into protobuf message.""" - client_message = ClientMessage() - client_message.ParseFromString(task_dict["legacy_client_message"]) + recordset = RecordSet() + recordset.ParseFromString(task_dict["recordset"]) result = TaskRes( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], @@ -626,7 +623,8 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: delivered_at=task_dict["delivered_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), - legacy_client_message=client_message, + task_type=task_dict["task_type"], + recordset=recordset, ), ) return result diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/superlink/state/sqlite_state_test.py similarity index 83% rename from src/py/flwr/server/state/sqlite_state_test.py rename to src/py/flwr/server/superlink/state/sqlite_state_test.py index da8fead1438e..9eef71e396e3 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/superlink/state/sqlite_state_test.py @@ -13,12 +13,12 @@ # limitations under the License. 
# ============================================================================== """Test for utility functions.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import unittest -from flwr.server.state.sqlite_state import task_ins_to_dict -from flwr.server.state.state_test import create_task_ins +from flwr.server.superlink.state.sqlite_state import task_ins_to_dict +from flwr.server.superlink.state.state_test import create_task_ins class SqliteStateTest(unittest.TestCase): @@ -27,11 +27,11 @@ class SqliteStateTest(unittest.TestCase): def test_ins_res_to_dict(self) -> None: """Check if all required keys are included in return value.""" # Prepare - ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id=0) + ins_res = create_task_ins(consumer_node_id=1, anonymous=True, run_id=0) expected_keys = [ "task_id", "group_id", - "workload_id", + "run_id", "producer_anonymous", "producer_node_id", "consumer_anonymous", @@ -40,8 +40,8 @@ def test_ins_res_to_dict(self) -> None: "delivered_at", "ttl", "ancestry", - "legacy_server_message", - "legacy_client_message", + "task_type", + "recordset", ] # Execute diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/superlink/state/state.py similarity index 93% rename from src/py/flwr/server/state/state.py rename to src/py/flwr/server/superlink/state/state.py index fd8bbc8e8e25..9337ae6d8624 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -19,7 +19,7 @@ from typing import List, Optional, Set from uuid import UUID -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 class State(abc.ABC): @@ -43,7 +43,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: If `task_ins.task.consumer.anonymous` is `False`, then `task_ins.task.consumer.node_id` MUST be set (not 0) - If `task_ins.workload_id` is invalid, then + If 
`task_ins.run_id` is invalid, then storing the `task_ins` MUST fail. """ @@ -92,7 +92,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: If `task_res.task.consumer.anonymous` is `False`, then `task_res.task.consumer.node_id` MUST be set (not 0) - If `task_res.workload_id` is invalid, then + If `task_res.run_id` is invalid, then storing the `task_res` MUST fail. """ @@ -140,15 +140,15 @@ def delete_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ @abc.abstractmethod - def create_workload(self) -> int: - """Create one workload.""" + def create_run(self) -> int: + """Create one run.""" diff --git a/src/py/flwr/server/state/state_factory.py b/src/py/flwr/server/superlink/state/state_factory.py similarity index 100% rename from src/py/flwr/server/state/state_factory.py rename to src/py/flwr/server/superlink/state/state_factory.py diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py similarity index 81% rename from src/py/flwr/server/state/state_test.py rename to src/py/flwr/server/superlink/state/state_test.py index 59299451c3d8..d0470a7ce7f7 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import tempfile import unittest @@ -22,10 +22,10 @@ from typing import List from uuid import uuid4 -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.server.state import InMemoryState, SqliteState, State +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import InMemoryState, SqliteState, State class StateTest(unittest.TestCase): @@ -66,9 +66,9 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) assert task_ins.task.created_at == "" # pylint: disable=no-member @@ -108,15 +108,15 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_0 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_1 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_2 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + 
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) # Insert three TaskIns @@ -136,7 +136,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_0)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_0) @@ -147,7 +147,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_1)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_1) @@ -182,10 +182,8 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -199,10 +197,8 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -215,10 +211,8 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -231,10 
+225,8 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -250,10 +242,8 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -278,13 +268,11 @@ def test_get_task_ins_limit_throws_for_limit_zero(self) -> None: with self.assertRaises(AssertionError): state.get_task_ins(node_id=1, limit=0) - def test_task_ins_store_invalid_workload_id_and_fail(self) -> None: - """Store TaskIns with invalid workload_id and fail.""" + def test_task_ins_store_invalid_run_id_and_fail(self) -> None: + """Store TaskIns with invalid run_id and fail.""" # Prepare state: State = self.state_factory() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=61016 - ) + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=61016) # Execute task_id = state.store_task_ins(task_ins) @@ -297,13 +285,13 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, anonymous=True, ancestry=[str(task_ins_id)], - 
workload_id=workload_id, + run_id=run_id, ) # Execute @@ -318,10 +306,10 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() # Execute - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -330,13 +318,13 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_ids = [] # Execute for _ in range(10): node_ids.append(state.create_node()) - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert for i in retrieved_node_ids: @@ -346,26 +334,26 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_id = state.create_node() # Execute state.delete_node(node_id) - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 - def test_get_nodes_invalid_workload_id(self) -> None: - """Test retrieving all node_ids with invalid workload_id.""" + def test_get_nodes_invalid_run_id(self) -> None: + """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_workload() - invalid_workload_id = 61016 + state.create_run() + invalid_run_id = 61016 state.create_node() # Execute - retrieved_node_ids = state.get_nodes(invalid_workload_id) + retrieved_node_ids = state.get_nodes(invalid_run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -374,13 +362,9 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered 
task_ins.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_0 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) - task_1 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Store two tasks state.store_task_ins(task_0) @@ -396,12 +380,12 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_0 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) task_1 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) # Store two tasks @@ -418,7 +402,7 @@ def test_num_task_res(self) -> None: def create_task_ins( consumer_node_id: int, anonymous: bool, - workload_id: int, + run_id: int, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -429,14 +413,13 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -446,20 +429,19 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - workload_id: int, + run_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( task_id="", 
group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py new file mode 100644 index 000000000000..01143af74392 --- /dev/null +++ b/src/py/flwr/server/typing.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Custom types for Flower servers.""" + + +from typing import Callable + +from flwr.common import Context + +from .driver import Driver + +ServerAppCallable = Callable[[Driver, Context], None] +Workflow = Callable[[Driver, Context], None] diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index fd89a01e4a4e..f9b271beafdc 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -17,7 +17,7 @@ from typing import List, Union -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 # pylint: disable-next=too-many-branches,too-many-statements @@ -64,21 +64,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_server_message": tasks_ins_res.task.HasField( - "legacy_server_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_server_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_server_message" - ] and not tasks_ins_res.task.legacy_server_message.HasField("msg"): - validation_errors.append("`legacy_server_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -115,21 +104,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - 
has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_client_message": tasks_ins_res.task.HasField( - "legacy_client_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_client_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_client_message" - ] and not tasks_ins_res.task.legacy_client_message.HasField("msg"): - validation_errors.append("`legacy_client_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index cab51fbf46de..8e0849508020 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -18,9 +18,9 @@ import unittest from typing import List, Tuple -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .validator import validate_task_ins_or_res @@ -37,16 +37,12 @@ def test_task_ins(self) -> None: # Execute & Assert for consumer_node_id, anonymous in valid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for consumer_node_id, anonymous in invalid_ins: - msg = create_task_ins( - 
consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors) @@ -70,61 +66,19 @@ def test_is_valid_task_res(self) -> None: # Execute & Assert for producer_node_id, anonymous, ancestry in valid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for producer_node_id, anonymous, ancestry in invalid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry)) - def test_task_ins_secure_aggregation(self) -> None: - """Test is_valid task_ins for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_ins = [(True, True), (False, True)] - invalid_ins = [(False, False)] - - # Execute & Assert - for has_legacy_server_message, has_sa in valid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_server_message, has_sa in invalid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - - def test_task_res_secure_aggregation(self) -> None: - """Test is_valid task_res for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_res = [(True, True), (False, True)] - invalid_res = [(False, False)] - - # Execute & Assert - for has_legacy_client_message, has_sa in valid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - 
self.assertFalse(val_errors) - - for has_legacy_client_message, has_sa in invalid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - def create_task_ins( consumer_node_id: int, anonymous: bool, - has_legacy_server_message: bool = False, - has_sa: bool = False, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -135,17 +89,13 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ) - if has_legacy_server_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -155,24 +105,18 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - has_legacy_client_message: bool = False, - has_sa: bool = False, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ) - if has_legacy_client_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res diff --git a/src/py/flwr/server/workflow/__init__.py b/src/py/flwr/server/workflow/__init__.py new file mode 100644 index 000000000000..31dee89a185d --- /dev/null +++ b/src/py/flwr/server/workflow/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Workflows.""" + + +from .default_workflows import DefaultWorkflow +from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow + +__all__ = [ + "DefaultWorkflow", + "SecAggPlusWorkflow", + "SecAggWorkflow", +] diff --git a/src/py/flwr/server/workflow/constant.py b/src/py/flwr/server/workflow/constant.py new file mode 100644 index 000000000000..068e05b27e12 --- /dev/null +++ b/src/py/flwr/server/workflow/constant.py @@ -0,0 +1,32 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Constants for default workflows.""" + + +from __future__ import annotations + +MAIN_CONFIGS_RECORD = "config" +MAIN_PARAMS_RECORD = "parameters" + + +class Key: + """Constants for default workflows.""" + + CURRENT_ROUND = "current_round" + START_TIME = "start_time" + + def __new__(cls) -> Key: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py new file mode 100644 index 000000000000..876ae56dcadc --- /dev/null +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -0,0 +1,347 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Legacy default workflows.""" + + +import io +import timeit +from logging import INFO +from typing import Optional, cast + +import flwr.common.recordset_compat as compat +from flwr.common import ConfigsRecord, Context, GetParametersIns, log +from flwr.common.constant import MessageType, MessageTypeLegacy + +from ..compat.app_utils import start_update_client_manager_thread +from ..compat.legacy_context import LegacyContext +from ..driver import Driver +from ..typing import Workflow +from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key + + +class DefaultWorkflow: + """Default workflow in Flower.""" + + def __init__( + self, + fit_workflow: Optional[Workflow] = None, + evaluate_workflow: Optional[Workflow] = None, + ) -> None: + if fit_workflow is None: + fit_workflow = default_fit_workflow + if evaluate_workflow is None: + evaluate_workflow = default_evaluate_workflow + self.fit_workflow: Workflow = fit_workflow + self.evaluate_workflow: Workflow = evaluate_workflow + + def __call__(self, driver: Driver, context: Context) -> None: + """Execute the workflow.""" + if not isinstance(context, LegacyContext): + raise TypeError( + f"Expect a LegacyContext, but get {type(context).__name__}." 
+ ) + + # Start the thread updating nodes + thread, f_stop = start_update_client_manager_thread( + driver, context.client_manager + ) + + # Initialize parameters + log(INFO, "[INIT]") + default_init_params_workflow(driver, context) + + # Run federated learning for num_rounds + start_time = timeit.default_timer() + cfg = ConfigsRecord() + cfg[Key.START_TIME] = start_time + context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg + + for current_round in range(1, context.config.num_rounds + 1): + log(INFO, "") + log(INFO, "[ROUND %s]", current_round) + cfg[Key.CURRENT_ROUND] = current_round + + # Fit round + self.fit_workflow(driver, context) + + # Centralized evaluation + default_centralized_evaluation_workflow(driver, context) + + # Evaluate round + self.evaluate_workflow(driver, context) + + # Bookkeeping and log results + end_time = timeit.default_timer() + elapsed = end_time - start_time + hist = context.history + log(INFO, "") + log(INFO, "[SUMMARY]") + log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed) + for idx, line in enumerate(io.StringIO(str(hist))): + if idx == 0: + log(INFO, "%s", line.strip("\n")) + else: + log(INFO, "\t%s", line.strip("\n")) + log(INFO, "") + + # Terminate the thread + f_stop.set() + thread.join() + + +def default_init_params_workflow(driver: Driver, context: Context) -> None: + """Execute the default workflow for parameters initialization.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + parameters = context.strategy.initialize_parameters( + client_manager=context.client_manager + ) + if parameters is not None: + log(INFO, "Using initial global parameters provided by strategy") + paramsrecord = compat.parameters_to_parametersrecord( + parameters, keep_input=True + ) + else: + # Get initial parameters from one of the clients + log(INFO, "Requesting initial parameters from one random client") + random_client = 
context.client_manager.sample(1)[0] + # Send GetParametersIns and get the response + content = compat.getparametersins_to_recordset(GetParametersIns({})) + messages = driver.send_and_receive( + [ + driver.create_message( + content=content, + message_type=MessageTypeLegacy.GET_PARAMETERS, + dst_node_id=random_client.node_id, + group_id="0", + ttl="", + ) + ] + ) + log(INFO, "Received initial parameters from one random client") + msg = list(messages)[0] + paramsrecord = next(iter(msg.content.parameters_records.values())) + + context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord + + # Evaluate initial parameters + log(INFO, "Evaluating initial global parameters") + parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True) + res = context.strategy.evaluate(0, parameters=parameters) + if res is not None: + log( + INFO, + "initial parameters (loss, other metrics): %s, %s", + res[0], + res[1], + ) + context.history.add_loss_centralized(server_round=0, loss=res[0]) + context.history.add_metrics_centralized(server_round=0, metrics=res[1]) + + +def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None: + """Execute the default workflow for centralized evaluation.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + # Retrieve current_round and start_time from the context + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + current_round = cast(int, cfg[Key.CURRENT_ROUND]) + start_time = cast(float, cfg[Key.START_TIME]) + + # Centralized evaluation + parameters = compat.parametersrecord_to_parameters( + record=context.state.parameters_records[MAIN_PARAMS_RECORD], + keep_input=True, + ) + res_cen = context.strategy.evaluate(current_round, parameters=parameters) + if res_cen is not None: + loss_cen, metrics_cen = res_cen + log( + INFO, + "fit progress: (%s, %s, %s, %s)", + current_round, + loss_cen, + metrics_cen, + 
timeit.default_timer() - start_time, + ) + context.history.add_loss_centralized(server_round=current_round, loss=loss_cen) + context.history.add_metrics_centralized( + server_round=current_round, metrics=metrics_cen + ) + + +def default_fit_workflow( # pylint: disable=R0914 + driver: Driver, context: Context +) -> None: + """Execute the default workflow for a single fit round.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + # Get current_round and parameters + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + current_round = cast(int, cfg[Key.CURRENT_ROUND]) + parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] + parameters = compat.parametersrecord_to_parameters( + parametersrecord, keep_input=True + ) + + # Get clients and their respective instructions from strategy + client_instructions = context.strategy.configure_fit( + server_round=current_round, + parameters=parameters, + client_manager=context.client_manager, + ) + + if not client_instructions: + log(INFO, "configure_fit: no clients selected, cancel") + return + log( + INFO, + "configure_fit: strategy sampled %s clients (out of %s)", + len(client_instructions), + context.client_manager.num_available(), + ) + + # Build dictionary mapping node_id to ClientProxy + node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions} + + # Build out messages + out_messages = [ + driver.create_message( + content=compat.fitins_to_recordset(fitins, True), + message_type=MessageType.TRAIN, + dst_node_id=proxy.node_id, + group_id=str(current_round), + ttl="", + ) + for proxy, fitins in client_instructions + ] + + # Send instructions to clients and + # collect `fit` results from all clients participating in this round + messages = list(driver.send_and_receive(out_messages)) + del out_messages + num_failures = len([msg for msg in messages if msg.has_error()]) + + # No exception/failure handling 
currently + log( + INFO, + "aggregate_fit: received %s results and %s failures", + len(messages) - num_failures, + num_failures, + ) + + # Aggregate training results + results = [ + ( + node_id_to_proxy[msg.metadata.src_node_id], + compat.recordset_to_fitres(msg.content, False), + ) + for msg in messages + ] + aggregated_result = context.strategy.aggregate_fit(current_round, results, []) + parameters_aggregated, metrics_aggregated = aggregated_result + + # Update the parameters and write history + if parameters_aggregated: + paramsrecord = compat.parameters_to_parametersrecord( + parameters_aggregated, True + ) + context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord + context.history.add_metrics_distributed_fit( + server_round=current_round, metrics=metrics_aggregated + ) + + +def default_evaluate_workflow(driver: Driver, context: Context) -> None: + """Execute the default workflow for a single evaluate round.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + # Get current_round and parameters + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + current_round = cast(int, cfg[Key.CURRENT_ROUND]) + parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] + parameters = compat.parametersrecord_to_parameters( + parametersrecord, keep_input=True + ) + + # Get clients and their respective instructions from strategy + client_instructions = context.strategy.configure_evaluate( + server_round=current_round, + parameters=parameters, + client_manager=context.client_manager, + ) + if not client_instructions: + log(INFO, "configure_evaluate: no clients selected, skipping evaluation") + return + log( + INFO, + "configure_evaluate: strategy sampled %s clients (out of %s)", + len(client_instructions), + context.client_manager.num_available(), + ) + + # Build dictionary mapping node_id to ClientProxy + node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in 
client_instructions} + + # Build out messages + out_messages = [ + driver.create_message( + content=compat.evaluateins_to_recordset(evalins, True), + message_type=MessageType.EVALUATE, + dst_node_id=proxy.node_id, + group_id=str(current_round), + ttl="", + ) + for proxy, evalins in client_instructions + ] + + # Send instructions to clients and + # collect `evaluate` results from all clients participating in this round + messages = list(driver.send_and_receive(out_messages)) + del out_messages + num_failures = len([msg for msg in messages if msg.has_error()]) + + # No exception/failure handling currently + log( + INFO, + "aggregate_evaluate: received %s results and %s failures", + len(messages) - num_failures, + num_failures, + ) + + # Aggregate the evaluation results + results = [ + ( + node_id_to_proxy[msg.metadata.src_node_id], + compat.recordset_to_evaluateres(msg.content), + ) + for msg in messages + ] + aggregated_result = context.strategy.aggregate_evaluate(current_round, results, []) + + loss_aggregated, metrics_aggregated = aggregated_result + + # Write history + if loss_aggregated is not None: + context.history.add_loss_distributed( + server_round=current_round, loss=loss_aggregated + ) + context.history.add_metrics_distributed( + server_round=current_round, metrics=metrics_aggregated + ) diff --git a/src/py/flwr/client/secure_aggregation/__init__.py b/src/py/flwr/server/workflow/secure_aggregation/__init__.py similarity index 72% rename from src/py/flwr/client/secure_aggregation/__init__.py rename to src/py/flwr/server/workflow/secure_aggregation/__init__.py index 37c816a390de..25e2a32da334 100644 --- a/src/py/flwr/client/secure_aggregation/__init__.py +++ b/src/py/flwr/server/workflow/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Secure Aggregation handlers.""" +"""Secure Aggregation workflows.""" -from .handler import SecureAggregationHandler -from .secaggplus_handler import SecAggPlusHandler +from .secagg_workflow import SecAggWorkflow +from .secaggplus_workflow import SecAggPlusWorkflow __all__ = [ - "SecAggPlusHandler", - "SecureAggregationHandler", + "SecAggPlusWorkflow", + "SecAggWorkflow", ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py new file mode 100644 index 000000000000..f56423e4a0d0 --- /dev/null +++ b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py @@ -0,0 +1,112 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Workflow for the SecAgg protocol.""" + + +from typing import Optional, Union + +from .secaggplus_workflow import SecAggPlusWorkflow + + +class SecAggWorkflow(SecAggPlusWorkflow): + """The workflow for the SecAgg protocol. 
+ + The SecAgg protocol ensures the secure summation of integer vectors owned by + multiple parties, without accessing any individual integer vector. This workflow + allows the server to compute the weighted average of model parameters across all + clients, ensuring individual contributions remain private. This is achieved by + clients sending both, a weighting factor and a weighted version of the locally + updated parameters, both of which are masked for privacy. Specifically, each + client uploads "[w, w * params]" with masks, where weighting factor 'w' is the + number of examples ('num_examples') and 'params' represents the model parameters + ('parameters') from the client's `FitRes`. The server then aggregates these + contributions to compute the weighted average of model parameters. + + The protocol involves four main stages: + - 'setup': Send SecAgg configuration to clients and collect their public keys. + - 'share keys': Broadcast public keys among clients and collect encrypted secret + key shares. + - 'collect masked vectors': Forward encrypted secret key shares to target clients + and collect masked model parameters. + - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters. + + Only the aggregated model parameters are exposed and passed to + `Strategy.aggregate_fit`, ensuring individual data privacy. + + Parameters + ---------- + reconstruction_threshold : Union[int, float] + The minimum number of shares required to reconstruct a client's private key, + or, if specified as a float, it represents the proportion of the total number + of shares needed for reconstruction. This threshold ensures privacy by allowing + for the recovery of contributions from dropped clients during aggregation, + without compromising individual client data. 
+ max_weight : Optional[float] (default: 1000.0) + The maximum value of the weight that can be assigned to any single client's + update during the weighted average calculation on the server side, e.g., in the + FedAvg algorithm. + clipping_range : float, optional (default: 8.0) + The range within which model parameters are clipped before quantization. + This parameter ensures each model parameter is bounded within + [-clipping_range, clipping_range], facilitating quantization. + quantization_range : int, optional (default: 4194304, this equals 2**22) + The size of the range into which floating-point model parameters are quantized, + mapping each parameter to an integer in [0, quantization_range-1]. This + facilitates cryptographic operations on the model updates. + modulus_range : int, optional (default: 4294967296, this equals 2**32) + The range of values from which random mask entries are uniformly sampled + ([0, modulus_range-1]). `modulus_range` must be less than 4294967296. + Please use 2**n values for `modulus_range` to prevent overflow issues. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the workflow will wait for + replies for this duration each time. If `None`, there is no time limit and + the workflow will wait until replies for all messages are received. + + Notes + ----- + - Each client's private key is split into N shares under the SecAgg protocol, where + N is the number of selected clients. + - Generally, higher `reconstruction_threshold` means better privacy guarantees but + less tolerance to dropouts. + - Too large `max_weight` may compromise the precision of the quantization. + - `modulus_range` must be 2**n and larger than `quantization_range`. + - When `reconstruction_threshold` is a float, it is interpreted as the proportion of + the number of all selected clients needed for the reconstruction of a private key. 
+ This feature enables flexibility in setting the security threshold relative to the + number of selected clients. + - `reconstruction_threshold`, and the quantization parameters + (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in + balancing privacy, robustness, and efficiency within the SecAgg protocol. + """ + + def __init__( # pylint: disable=R0913 + self, + reconstruction_threshold: Union[int, float], + *, + max_weight: float = 1000.0, + clipping_range: float = 8.0, + quantization_range: int = 4194304, + modulus_range: int = 4294967296, + timeout: Optional[float] = None, + ) -> None: + super().__init__( + num_shares=1.0, + reconstruction_threshold=reconstruction_threshold, + max_weight=max_weight, + clipping_range=clipping_range, + quantization_range=quantization_range, + modulus_range=modulus_range, + timeout=timeout, + ) diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py new file mode 100644 index 000000000000..42ee9c15f1cd --- /dev/null +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -0,0 +1,673 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Workflow for the SecAgg+ protocol.""" + + +import random +from dataclasses import dataclass, field +from logging import DEBUG, ERROR, INFO, WARN +from typing import Dict, List, Optional, Set, Tuple, Union, cast + +import flwr.common.recordset_compat as compat +from flwr.common import ( + ConfigsRecord, + Context, + FitRes, + Message, + MessageType, + NDArrays, + RecordSet, + bytes_to_ndarray, + log, + ndarrays_to_parameters, +) +from flwr.common.secure_aggregation.crypto.shamir import combine_shares +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_private_key, + bytes_to_public_key, + generate_shared_key, +) +from flwr.common.secure_aggregation.ndarrays_arithmetic import ( + factor_extract, + get_parameters_shape, + parameters_addition, + parameters_mod, + parameters_subtraction, +) +from flwr.common.secure_aggregation.quantization import dequantize +from flwr.common.secure_aggregation.secaggplus_constants import ( + RECORD_KEY_CONFIGS, + Key, + Stage, +) +from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen +from flwr.server.client_proxy import ClientProxy +from flwr.server.compat.legacy_context import LegacyContext +from flwr.server.driver import Driver + +from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD +from ..constant import Key as WorkflowKey + + +@dataclass +class WorkflowState: # pylint: disable=R0902 + """The state of the SecAgg+ protocol.""" + + nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict) + nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict) + sampled_node_ids: Set[int] = field(default_factory=set) + active_node_ids: Set[int] = field(default_factory=set) + num_shares: int = 0 + threshold: int = 0 + clipping_range: float = 0.0 + quantization_range: int = 0 + mod_range: int = 0 + max_weight: float = 0.0 + nid_to_neighbours: Dict[int, Set[int]] = 
field(default_factory=dict) + nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict) + forward_srcs: Dict[int, List[int]] = field(default_factory=dict) + forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict) + aggregate_ndarrays: NDArrays = field(default_factory=list) + legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list) + + +class SecAggPlusWorkflow: + """The workflow for the SecAgg+ protocol. + + The SecAgg+ protocol ensures the secure summation of integer vectors owned by + multiple parties, without accessing any individual integer vector. This workflow + allows the server to compute the weighted average of model parameters across all + clients, ensuring individual contributions remain private. This is achieved by + clients sending both a weighting factor and a weighted version of the locally + updated parameters, both of which are masked for privacy. Specifically, each + client uploads "[w, w * params]" with masks, where weighting factor 'w' is the + number of examples ('num_examples') and 'params' represents the model parameters + ('parameters') from the client's `FitRes`. The server then aggregates these + contributions to compute the weighted average of model parameters. + + The protocol involves four main stages: + - 'setup': Send SecAgg+ configuration to clients and collect their public keys. + - 'share keys': Broadcast public keys among clients and collect encrypted secret + key shares. + - 'collect masked vectors': Forward encrypted secret key shares to target clients + and collect masked model parameters. + - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters. + + Only the aggregated model parameters are exposed and passed to + `Strategy.aggregate_fit`, ensuring individual data privacy. + + Parameters + ---------- + num_shares : Union[int, float] + The number of shares into which each client's private key is split under + the SecAgg+ protocol. 
If specified as a float, it represents the proportion + of all selected clients, and the number of shares will be set dynamically in + the run time. A private key can be reconstructed from these shares, allowing + for the secure aggregation of model updates. Each client sends one share to + each of its neighbors while retaining one. + reconstruction_threshold : Union[int, float] + The minimum number of shares required to reconstruct a client's private key, + or, if specified as a float, it represents the proportion of the total number + of shares needed for reconstruction. This threshold ensures privacy by allowing + for the recovery of contributions from dropped clients during aggregation, + without compromising individual client data. + max_weight : Optional[float] (default: 1000.0) + The maximum value of the weight that can be assigned to any single client's + update during the weighted average calculation on the server side, e.g., in the + FedAvg algorithm. + clipping_range : float, optional (default: 8.0) + The range within which model parameters are clipped before quantization. + This parameter ensures each model parameter is bounded within + [-clipping_range, clipping_range], facilitating quantization. + quantization_range : int, optional (default: 4194304, this equals 2**22) + The size of the range into which floating-point model parameters are quantized, + mapping each parameter to an integer in [0, quantization_range-1]. This + facilitates cryptographic operations on the model updates. + modulus_range : int, optional (default: 4294967296, this equals 2**32) + The range of values from which random mask entries are uniformly sampled + ([0, modulus_range-1]). `modulus_range` must be less than 4294967296. + Please use 2**n values for `modulus_range` to prevent overflow issues. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the workflow will wait for + replies for this duration each time. 
If `None`, there is no time limit and + the workflow will wait until replies for all messages are received. + + Notes + ----- + - Generally, higher `num_shares` means more robust to dropouts while increasing the + computational costs; higher `reconstruction_threshold` means better privacy + guarantees but less tolerance to dropouts. + - Too large `max_weight` may compromise the precision of the quantization. + - `modulus_range` must be 2**n and larger than `quantization_range`. + - When `num_shares` is a float, it is interpreted as the proportion of all selected + clients, and hence the number of shares will be determined in the runtime. This + allows for dynamic adjustment based on the total number of participating clients. + - Similarly, when `reconstruction_threshold` is a float, it is interpreted as the + proportion of the number of shares needed for the reconstruction of a private key. + This feature enables flexibility in setting the security threshold relative to the + number of distributed shares. + - `num_shares`, `reconstruction_threshold`, and the quantization parameters + (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in + balancing privacy, robustness, and efficiency within the SecAgg+ protocol. 
+ """ + + def __init__( # pylint: disable=R0913 + self, + num_shares: Union[int, float], + reconstruction_threshold: Union[int, float], + *, + max_weight: float = 1000.0, + clipping_range: float = 8.0, + quantization_range: int = 4194304, + modulus_range: int = 4294967296, + timeout: Optional[float] = None, + ) -> None: + self.num_shares = num_shares + self.reconstruction_threshold = reconstruction_threshold + self.max_weight = max_weight + self.clipping_range = clipping_range + self.quantization_range = quantization_range + self.modulus_range = modulus_range + self.timeout = timeout + + self._check_init_params() + + def __call__(self, driver: Driver, context: Context) -> None: + """Run the SecAgg+ protocol.""" + if not isinstance(context, LegacyContext): + raise TypeError( + f"Expect a LegacyContext, but get {type(context).__name__}." + ) + state = WorkflowState() + + steps = ( + self.setup_stage, + self.share_keys_stage, + self.collect_masked_vectors_stage, + self.unmask_stage, + ) + log(INFO, "Secure aggregation commencing.") + for step in steps: + if not step(driver, context, state): + log(INFO, "Secure aggregation halted.") + return + log(INFO, "Secure aggregation completed.") + + def _check_init_params(self) -> None: # pylint: disable=R0912 + # Check `num_shares` + if not isinstance(self.num_shares, (int, float)): + raise TypeError("`num_shares` must be of type int or float.") + if isinstance(self.num_shares, int): + if self.num_shares == 1: + self.num_shares = 1.0 + elif self.num_shares <= 2: + raise ValueError("`num_shares` as an integer must be greater than 2.") + elif self.num_shares > self.modulus_range / self.quantization_range: + log( + WARN, + "A `num_shares` larger than `modulus_range / quantization_range` " + "will potentially cause overflow when computing the aggregated " + "model parameters.", + ) + elif self.num_shares <= 0: + raise ValueError("`num_shares` as a float must be greater than 0.") + + # Check `reconstruction_threshold` + if not 
isinstance(self.reconstruction_threshold, (int, float)): + raise TypeError("`reconstruction_threshold` must be of type int or float.") + if isinstance(self.reconstruction_threshold, int): + if self.reconstruction_threshold == 1: + self.reconstruction_threshold = 1.0 + elif isinstance(self.num_shares, int): + if self.reconstruction_threshold >= self.num_shares: + raise ValueError( + "`reconstruction_threshold` must be less than `num_shares`." + ) + else: + if not 0 < self.reconstruction_threshold <= 1: + raise ValueError( + "If `reconstruction_threshold` is a float, " + "it must be greater than 0 and less than or equal to 1." + ) + + # Check `max_weight` + if self.max_weight <= 0: + raise ValueError("`max_weight` must be greater than 0.") + + # Check `quantization_range` + if self.quantization_range <= 0: + raise ValueError("`quantization_range` must be greater than 0.") + + # Check `quantization_range` + if not isinstance(self.quantization_range, int) or self.quantization_range <= 0: + raise ValueError( + "`quantization_range` must be an integer and greater than 0." + ) + + # Check `modulus_range` + if ( + not isinstance(self.modulus_range, int) + or self.modulus_range <= self.quantization_range + ): + raise ValueError( + "`modulus_range` must be an integer and " + "greater than `quantization_range`." 
+ ) + if bin(self.modulus_range).count("1") != 1: + raise ValueError("`modulus_range` must be a power of 2.") + + def _check_threshold(self, state: WorkflowState) -> bool: + for node_id in state.sampled_node_ids: + active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids + if len(active_neighbors) < state.threshold: + log(ERROR, "Insufficient available nodes.") + return False + return True + + def setup_stage( # pylint: disable=R0912, R0914, R0915 + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + """Execute the 'setup' stage.""" + # Obtain fit instructions + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND]) + parameters = compat.parametersrecord_to_parameters( + context.state.parameters_records[MAIN_PARAMS_RECORD], + keep_input=True, + ) + proxy_fitins_lst = context.strategy.configure_fit( + current_round, parameters, context.client_manager + ) + if not proxy_fitins_lst: + log(INFO, "configure_fit: no clients selected, cancel") + return False + log( + INFO, + "configure_fit: strategy sampled %s clients (out of %s)", + len(proxy_fitins_lst), + context.client_manager.num_available(), + ) + + state.nid_to_fitins = { + proxy.node_id: compat.fitins_to_recordset(fitins, True) + for proxy, fitins in proxy_fitins_lst + } + state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst} + + # Protocol config + sampled_node_ids = list(state.nid_to_fitins.keys()) + num_samples = len(sampled_node_ids) + if num_samples < 2: + log(ERROR, "The number of samples should be greater than 1.") + return False + if isinstance(self.num_shares, float): + state.num_shares = round(self.num_shares * num_samples) + # If even + if state.num_shares < num_samples and state.num_shares & 1 == 0: + state.num_shares += 1 + # If too small + if state.num_shares <= 2: + state.num_shares = num_samples + else: + state.num_shares = self.num_shares + if 
isinstance(self.reconstruction_threshold, float): + state.threshold = round(self.reconstruction_threshold * state.num_shares) + # Avoid too small threshold + state.threshold = max(state.threshold, 2) + else: + state.threshold = self.reconstruction_threshold + state.active_node_ids = set(sampled_node_ids) + state.clipping_range = self.clipping_range + state.quantization_range = self.quantization_range + state.mod_range = self.modulus_range + state.max_weight = self.max_weight + sa_params_dict = { + Key.STAGE: Stage.SETUP, + Key.SAMPLE_NUMBER: num_samples, + Key.SHARE_NUMBER: state.num_shares, + Key.THRESHOLD: state.threshold, + Key.CLIPPING_RANGE: state.clipping_range, + Key.TARGET_RANGE: state.quantization_range, + Key.MOD_RANGE: state.mod_range, + Key.MAX_WEIGHT: state.max_weight, + } + + # The number of shares should better be odd in the SecAgg+ protocol. + if num_samples != state.num_shares and state.num_shares & 1 == 0: + log(WARN, "Number of shares in the SecAgg+ protocol should be odd.") + state.num_shares += 1 + + # Shuffle node IDs + random.shuffle(sampled_node_ids) + # Build neighbour relations (node ID -> secure IDs of neighbours) + half_share = state.num_shares >> 1 + state.nid_to_neighbours = { + nid: { + sampled_node_ids[(idx + offset) % num_samples] + for offset in range(-half_share, half_share + 1) + } + for idx, nid in enumerate(sampled_node_ids) + } + + state.sampled_node_ids = state.active_node_ids + + # Send setup configuration to clients + cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore + content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + + def make(nid: int) -> Message: + return driver.create_message( + content=content, + message_type=MessageType.TRAIN, + dst_node_id=nid, + group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), + ttl="", + ) + + log( + DEBUG, + "[Stage 0] Sending configurations to %s clients.", + len(state.active_node_ids), + ) + msgs = driver.send_and_receive( + [make(node_id) for node_id in 
state.active_node_ids], timeout=self.timeout + ) + state.active_node_ids = { + msg.metadata.src_node_id for msg in msgs if not msg.has_error() + } + log( + DEBUG, + "[Stage 0] Received public keys from %s clients.", + len(state.active_node_ids), + ) + + for msg in msgs: + if msg.has_error(): + continue + key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] + node_id = msg.metadata.src_node_id + pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2] + state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)] + + return self._check_threshold(state) + + def share_keys_stage( # pylint: disable=R0914 + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + """Execute the 'share keys' stage.""" + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + + def make(nid: int) -> Message: + neighbours = state.nid_to_neighbours[nid] & state.active_node_ids + cfgs_record = ConfigsRecord( + {str(nid): state.nid_to_publickeys[nid] for nid in neighbours} + ) + cfgs_record[Key.STAGE] = Stage.SHARE_KEYS + content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + return driver.create_message( + content=content, + message_type=MessageType.TRAIN, + dst_node_id=nid, + group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), + ttl="", + ) + + # Broadcast public keys to clients and receive secret key shares + log( + DEBUG, + "[Stage 1] Forwarding public keys to %s clients.", + len(state.active_node_ids), + ) + msgs = driver.send_and_receive( + [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout + ) + state.active_node_ids = { + msg.metadata.src_node_id for msg in msgs if not msg.has_error() + } + log( + DEBUG, + "[Stage 1] Received encrypted key shares from %s clients.", + len(state.active_node_ids), + ) + + # Build forward packet list dictionary + srcs: List[int] = [] + dsts: List[int] = [] + ciphertexts: List[bytes] = [] + fwd_ciphertexts: Dict[int, List[bytes]] = { + nid: [] for nid in 
state.active_node_ids + } # dest node ID -> list of ciphertexts + fwd_srcs: Dict[int, List[int]] = { + nid: [] for nid in state.active_node_ids + } # dest node ID -> list of src node IDs + for msg in msgs: + node_id = msg.metadata.src_node_id + res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] + dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST]) + ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST]) + srcs += [node_id] * len(dst_lst) + dsts += dst_lst + ciphertexts += ctxt_lst + + for src, dst, ciphertext in zip(srcs, dsts, ciphertexts): + if dst in fwd_ciphertexts: + fwd_ciphertexts[dst].append(ciphertext) + fwd_srcs[dst].append(src) + + state.forward_srcs = fwd_srcs + state.forward_ciphertexts = fwd_ciphertexts + + return self._check_threshold(state) + + def collect_masked_vectors_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + """Execute the 'collect masked vectors' stage.""" + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + + # Send secret key shares to clients (plus FitIns) and collect masked vectors + def make(nid: int) -> Message: + cfgs_dict = { + Key.STAGE: Stage.COLLECT_MASKED_VECTORS, + Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid], + Key.SOURCE_LIST: state.forward_srcs[nid], + } + cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore + content = state.nid_to_fitins[nid] + content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record + return driver.create_message( + content=content, + message_type=MessageType.TRAIN, + dst_node_id=nid, + group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), + ttl="", + ) + + log( + DEBUG, + "[Stage 2] Forwarding encrypted key shares to %s clients.", + len(state.active_node_ids), + ) + msgs = driver.send_and_receive( + [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout + ) + state.active_node_ids = { + msg.metadata.src_node_id for msg in msgs if not msg.has_error() + } + log( + DEBUG, + "[Stage 2] Received masked vectors from %s 
clients.", + len(state.active_node_ids), + ) + + # Clear cache + del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins + + # Sum collected masked vectors and compute active/dead node IDs + masked_vector = None + for msg in msgs: + res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] + bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS]) + client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list] + if masked_vector is None: + masked_vector = client_masked_vec + else: + masked_vector = parameters_addition(masked_vector, client_masked_vec) + if masked_vector is not None: + masked_vector = parameters_mod(masked_vector, state.mod_range) + state.aggregate_ndarrays = masked_vector + + # Backward compatibility with Strategy + for msg in msgs: + fitres = compat.recordset_to_fitres(msg.content, True) + proxy = state.nid_to_proxies[msg.metadata.src_node_id] + state.legacy_results.append((proxy, fitres)) + + return self._check_threshold(state) + + def unmask_stage( # pylint: disable=R0912, R0914, R0915 + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + """Execute the 'unmask' stage.""" + cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] + current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND]) + + # Construct active node IDs and dead node IDs + active_nids = state.active_node_ids + dead_nids = state.sampled_node_ids - active_nids + + # Send secure IDs of active and dead clients and collect key shares from clients + def make(nid: int) -> Message: + neighbours = state.nid_to_neighbours[nid] + cfgs_dict = { + Key.STAGE: Stage.UNMASK, + Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids), + Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids), + } + cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore + content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + return driver.create_message( + content=content, + message_type=MessageType.TRAIN, + dst_node_id=nid, + 
group_id=str(current_round), + ttl="", + ) + + log( + DEBUG, + "[Stage 3] Requesting key shares from %s clients to remove masks.", + len(state.active_node_ids), + ) + msgs = driver.send_and_receive( + [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout + ) + state.active_node_ids = { + msg.metadata.src_node_id for msg in msgs if not msg.has_error() + } + log( + DEBUG, + "[Stage 3] Received key shares from %s clients.", + len(state.active_node_ids), + ) + + # Build collected shares dict + collected_shares_dict: Dict[int, List[bytes]] = {} + for nid in state.sampled_node_ids: + collected_shares_dict[nid] = [] + for msg in msgs: + res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] + nids = cast(List[int], res_dict[Key.NODE_ID_LIST]) + shares = cast(List[bytes], res_dict[Key.SHARE_LIST]) + for owner_nid, share in zip(nids, shares): + collected_shares_dict[owner_nid].append(share) + + # Remove masks for every active client after collect_masked_vectors stage + masked_vector = state.aggregate_ndarrays + del state.aggregate_ndarrays + for nid, share_list in collected_shares_dict.items(): + if len(share_list) < state.threshold: + log( + ERROR, "Not enough shares to recover secret in unmask vectors stage" + ) + return False + secret = combine_shares(share_list) + if nid in active_nids: + # The seed for PRG is the private mask seed of an active client. + private_mask = pseudo_rand_gen( + secret, state.mod_range, get_parameters_shape(masked_vector) + ) + masked_vector = parameters_subtraction(masked_vector, private_mask) + else: + # The seed for PRG is the secret key 1 of a dropped client. 
+ neighbours = state.nid_to_neighbours[nid] + neighbours.remove(nid) + + for neighbor_nid in neighbours: + shared_key = generate_shared_key( + bytes_to_private_key(secret), + bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]), + ) + pairwise_mask = pseudo_rand_gen( + shared_key, state.mod_range, get_parameters_shape(masked_vector) + ) + if nid > neighbor_nid: + masked_vector = parameters_addition( + masked_vector, pairwise_mask + ) + else: + masked_vector = parameters_subtraction( + masked_vector, pairwise_mask + ) + recon_parameters = parameters_mod(masked_vector, state.mod_range) + q_total_ratio, recon_parameters = factor_extract(recon_parameters) + inv_dq_total_ratio = state.quantization_range / q_total_ratio + # recon_parameters = parameters_divide(recon_parameters, total_weights_factor) + aggregated_vector = dequantize( + recon_parameters, + state.clipping_range, + state.quantization_range, + ) + offset = -(len(active_nids) - 1) * state.clipping_range + for vec in aggregated_vector: + vec += offset + vec *= inv_dq_total_ratio + + # Backward compatibility with Strategy + results = state.legacy_results + parameters = ndarrays_to_parameters(aggregated_vector) + for _, fitres in results: + fitres.parameters = parameters + + # No exception/failure handling currently + log( + INFO, + "aggregate_fit: received %s results and %s failures", + len(results), + 0, + ) + aggregated_result = context.strategy.aggregate_fit(current_round, results, []) + parameters_aggregated, metrics_aggregated = aggregated_result + + # Update the parameters and write history + if parameters_aggregated: + paramsrecord = compat.parameters_to_parametersrecord( + parameters_aggregated, True + ) + context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord + context.history.add_metrics_distributed_fit( + server_round=current_round, metrics=metrics_aggregated + ) + return True diff --git a/src/py/flwr/simulation/__init__.py b/src/py/flwr/simulation/__init__.py index 
724ea9273916..d36d9977d1c5 100644 --- a/src/py/flwr/simulation/__init__.py +++ b/src/py/flwr/simulation/__init__.py @@ -17,6 +17,8 @@ import importlib +from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli + is_ray_installed = importlib.util.find_spec("ray") is not None if is_ray_installed: @@ -34,6 +36,4 @@ def start_simulation(*args, **kwargs): # type: ignore raise ImportError(RAY_IMPORT_ERROR) -__all__ = [ - "start_simulation", -] +__all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"] diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 0bb9290b6911..ff18f37664be 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -28,13 +28,13 @@ from flwr.client import ClientFn from flwr.common import EventType, event from flwr.common.logger import log -from flwr.server import Server -from flwr.server.app import ServerConfig, init_defaults, run_fl from flwr.server.client_manager import ClientManager from flwr.server.history import History +from flwr.server.server import Server, init_defaults, run_fl +from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy from flwr.simulation.ray_transport.ray_actor import ( - DefaultActor, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, pool_size_from_resources, @@ -82,7 +82,7 @@ def start_simulation( client_manager: Optional[ClientManager] = None, ray_init_args: Optional[Dict[str, Any]] = None, keep_initialised: Optional[bool] = False, - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, actor_kwargs: Optional[Dict[str, Any]] = None, actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT", ) -> History: @@ -107,8 +107,8 @@ def start_simulation( List `client_id`s for each client. This is only required if `num_clients` is not set. 
Setting both `num_clients` and `clients_ids` with `len(clients_ids)` not equal to `num_clients` generates an error. - client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, - "num_gpus": 0.0}` CPU and GPU resources for a single client. Supported keys + client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`) + CPU and GPU resources for a single client. Supported keys are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by `num_gpus`, as well as using custom resources, please consult the Ray documentation. @@ -138,10 +138,10 @@ def start_simulation( keep_initialised: Optional[bool] (default: False) Set to True to prevent `ray.shutdown()` in case `ray.is_initialized()=True`. - actor_type: VirtualClientEngineActor (default: DefaultActor) + actor_type: VirtualClientEngineActor (default: ClientAppActor) Optionally specify the type of actor to use. The actor object, which persists throughout the simulation, will be the process in charge of - running the clients' jobs (i.e. their `fit()` method). + executing a ClientApp wrapping input argument `client_fn`. actor_kwargs: Optional[Dict[str, Any]] (default: None) If you want to create your own Actor classes, you might need to pass @@ -160,7 +160,7 @@ def start_simulation( ------- hist : flwr.server.history.History Object containing metrics from training. - """ + """ # noqa: E501 # pylint: disable-msg=too-many-locals event( EventType.START_SIMULATION_ENTER, @@ -219,7 +219,7 @@ def start_simulation( log( INFO, "Optimize your simulation with Flower VCE: " - "https://flower.dev/docs/framework/how-to-run-simulations.html", + "https://flower.ai/docs/framework/how-to-run-simulations.html", ) # Log the resources that a single client will be able to use @@ -314,18 +314,30 @@ def update_resources(f_stop: threading.Event) -> None: log(ERROR, traceback.format_exc()) log( ERROR, - "Your simulation crashed :(. This could be because of several reasons." 
+ "Your simulation crashed :(. This could be because of several reasons. " "The most common are: " + "\n\t > Sometimes, issues in the simulation code itself can cause crashes. " + "It's always a good idea to double-check your code for any potential bugs " + "or inconsistencies that might be contributing to the problem. " + "For example: " + "\n\t\t - You might be using a class attribute in your clients that " + "hasn't been defined." + "\n\t\t - There could be an incorrect method call to a 3rd party library " + "(e.g., PyTorch)." + "\n\t\t - The return types of methods in your clients/strategies might be " + "incorrect." "\n\t > Your system couldn't fit a single VirtualClient: try lowering " "`client_resources`." "\n\t > All the actors in your pool crashed. This could be because: " "\n\t\t - You clients hit an out-of-memory (OOM) error and actors couldn't " "recover from it. Try launching your simulation with more generous " "`client_resources` setting (i.e. it seems %s is " - "not enough for your workload). Use fewer concurrent actors. " + "not enough for your run). Use fewer concurrent actors. " "\n\t\t - You were running a multi-node simulation and all worker nodes " "disconnected. The head node might still be alive but cannot accommodate " - "any actor with resources: %s.", + "any actor with resources: %s." 
+ "\nTake a look at the Flower simulation examples for guidance " + ".", client_resources, client_resources, ) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 640817910396..08d0576e39f0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -14,29 +14,22 @@ # ============================================================================== """Ray-based Flower Actor and ActorPool implementation.""" - +import asyncio import threading import traceback from abc import ABC -from logging import ERROR, WARNING +from logging import DEBUG, ERROR, WARNING from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import ray from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.workload_state import WorkloadState +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common import Context, Message from flwr.common.logger import log -from flwr.simulation.ray_transport.utils import check_clientfn_returns_client -# All possible returns by a client -ClientRes = Union[ - common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes -] -# A function to be executed by a client to obtain some results -JobFn = Callable[[Client], ClientRes] +ClientAppFn = Callable[[], ClientApp] class ClientException(Exception): @@ -53,48 +46,49 @@ class VirtualClientEngineActor(ABC): def terminate(self) -> None: """Manually terminate Actor object.""" - log(WARNING, "Manually terminating %s}", self.__class__.__name__) + log(WARNING, "Manually terminating %s", self.__class__.__name__) ray.actor.exit_actor() def run( self, - client_fn: ClientFn, - job_fn: JobFn, + client_app_fn: ClientAppFn, + message: Message, cid: str, - state: WorkloadState, - ) -> Tuple[str, ClientRes, WorkloadState]: - """Run a client 
workload.""" - # Execute tasks and return result + context: Context, + ) -> Tuple[str, Message, Context]: + """Run a client run.""" + # Pass message through ClientApp and return a message # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: - # Instantiate client (check 'Client' type is returned) - client = check_clientfn_returns_client(client_fn(cid)) - # Inject state - client.set_state(state) - # Run client job - job_results = job_fn(client) - # Retrieve state (potentially updated) - updated_state = client.get_state() + # Load app + app: ClientApp = client_app_fn() + + # Handle task message + out_message = app(message=message, context=context) + + except LoadClientAppError as load_ex: + raise load_ex + except Exception as ex: client_trace = traceback.format_exc() - message = ( - "\n\tSomething went wrong when running your client workload." + mssg = ( + "\n\tSomething went wrong when running your client run." "\n\tClient " + cid + " crashed when the " + self.__class__.__name__ - + " was running its workload." + + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) - raise ClientException(str(message)) from ex + raise ClientException(str(mssg)) from ex - return cid, job_results, updated_state + return cid, out_message, context @ray.remote -class DefaultActor(VirtualClientEngineActor): - """A Ray Actor class that runs client workloads. +class ClientAppActor(VirtualClientEngineActor): + """A Ray Actor class that runs client runs. Parameters ---------- @@ -237,18 +231,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit( - self, fn: Any, value: Tuple[ClientFn, JobFn, str, WorkloadState] - ) -> None: - """Take idle actor and assign it a client workload. 
+ def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None: + """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then - check if this actor was flagged to be removed from the pool + check if this actor was flagged to be removed from the pool. """ - client_fn, job_fn, cid, state = value + app_fn, mssg, cid, context = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, state) + future = fn(actor, app_fn, mssg, cid, context) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -257,7 +249,7 @@ def submit( self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, WorkloadState] + self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -297,17 +289,17 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: - """Fetch result and updated state for a VirtualClient from Object Store. + def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: + """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is ready. Here we fetch it from the object store and return. 
""" try: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore - res_cid, res, updated_state = ray.get( + res_cid, out_mssg, updated_context = ray.get( future - ) # type: (str, ClientRes, WorkloadState) + ) # type: (str, Message, Context) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -324,7 +316,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: # Reset mapping self._reset_cid_to_future_dict(cid) - return res, updated_state + return out_mssg, updated_context def _flag_actor_for_removal(self, actor_id_hex: str) -> None: """Flag actor that should be removed from pool.""" @@ -411,7 +403,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, WorkloadState]: + ) -> Tuple[Message, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. 
Break early # if the result for the ClientProxy calling this method is ready @@ -423,5 +415,92 @@ def get_client_result( break # Fetch result belonging to the VirtualClient calling this method - # Return both result from tasks and (potentially) updated workload state + # Return both result from tasks and (potentially) updated run context return self._fetch_future_result(cid) + + +def init_ray(*args: Any, **kwargs: Any) -> None: + """Intialises Ray if not already initialised.""" + if not ray.is_initialized(): + ray.init(*args, **kwargs) + + +class BasicActorPool: + """A basic actor pool.""" + + def __init__( + self, + actor_type: Type[VirtualClientEngineActor], + client_resources: Dict[str, Union[int, float]], + actor_kwargs: Dict[str, Any], + ): + self.client_resources = client_resources + + # Queue of idle actors + self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue( + maxsize=1024 + ) + self.num_actors = 0 + + # Resolve arguments to pass during actor init + actor_args = {} if actor_kwargs is None else actor_kwargs + + # A function that creates an actor + self.create_actor_fn = lambda: actor_type.options( # type: ignore + **client_resources + ).remote(**actor_args) + + # Figure out how many actors can be created given the cluster resources + # and the resources the user indicates each VirtualClient will need + self.actors_capacity = pool_size_from_resources(client_resources) + self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {} + + def is_actor_available(self) -> bool: + """Return true if there is an idle actor.""" + return self.pool.qsize() > 0 + + async def add_actors_to_pool(self, num_actors: int) -> None: + """Add actors to the pool. + + This method may be executed also if new resources are added to your Ray cluster + (e.g. you add a new node). 
+ """ + for _ in range(num_actors): + await self.pool.put(self.create_actor_fn()) # type: ignore + self.num_actors += num_actors + + async def terminate_all_actors(self) -> None: + """Terminate actors in pool.""" + num_terminated = 0 + while self.pool.qsize(): + actor = await self.pool.get() + actor.terminate.remote() # type: ignore + num_terminated += 1 + + log(DEBUG, "Terminated %i actors", num_terminated) + + async def submit( + self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + ) -> Any: + """On idle actor, submit job and return future.""" + # Remove idle actor from pool + actor = await self.pool.get() + # Submit job to actor + app_fn, mssg, cid, context = job + future = actor_fn(actor, app_fn, mssg, cid, context) + # Keep track of future:actor (so we can fetch the actor upon job completion + # and add it back to the pool) + self._future_to_actor[future] = actor + return future + + async def fetch_result_and_return_actor_to_pool( + self, future: Any + ) -> Tuple[Message, Context]: + """Pull result given a future and add actor back to pool.""" + # Get actor that ran job + actor = self._future_to_actor.pop(future) + await self.pool.put(actor) + # Retrieve result for object store + # Instead of doing ray.get(future) we await it + _, out_mssg, updated_context = await future + return out_mssg, updated_context diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index c6a63298dae6..c3493163ac52 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -17,107 +17,27 @@ import traceback from logging import ERROR -from typing import Dict, Optional, cast - -import ray +from typing import Optional from flwr import common -from flwr.client import Client, ClientFn -from flwr.client.client import ( - maybe_call_evaluate, - maybe_call_fit, - maybe_call_get_parameters, - maybe_call_get_properties, -) +from 
flwr.client import ClientFn +from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState +from flwr.common import Message, Metadata, RecordSet +from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.logger import log -from flwr.server.client_proxy import ClientProxy -from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - JobFn, - VirtualClientEngineActorPool, +from flwr.common.recordset_compat import ( + evaluateins_to_recordset, + fitins_to_recordset, + getparametersins_to_recordset, + getpropertiesins_to_recordset, + recordset_to_evaluateres, + recordset_to_fitres, + recordset_to_getparametersres, + recordset_to_getpropertiesres, ) - - -class RayClientProxy(ClientProxy): - """Flower client proxy which delegates work using Ray.""" - - def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]): - super().__init__(cid) - self.client_fn = client_fn - self.resources = resources - - def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] - ) -> common.GetPropertiesRes: - """Return client's properties.""" - future_get_properties_res = launch_and_get_properties.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_get_properties_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetPropertiesRes, - res, - ) - - def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] - ) -> common.GetParametersRes: - """Return the current local model parameters.""" - future_paramseters_res = launch_and_get_parameters.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_paramseters_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> 
common.FitRes: - """Train model parameters on the locally held dataset.""" - future_fit_res = launch_and_fit.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_fit_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.FitRes, - res, - ) - - def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] - ) -> common.EvaluateRes: - """Evaluate model parameters on the locally held dataset.""" - future_evaluate_res = launch_and_evaluate.options( # type: ignore - **self.resources, - ).remote(self.client_fn, self.cid, ins) - try: - res = ray.get(future_evaluate_res, timeout=timeout) - except Exception as ex: - log(ERROR, ex) - raise ex - return cast( - common.EvaluateRes, - res, - ) - - def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] - ) -> common.DisconnectRes: - """Disconnect and (optionally) reconnect later.""" - return common.DisconnectRes(reason="") # Nothing to do here (yet) +from flwr.server.client_proxy import ClientProxy +from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool class RayActorClientProxy(ClientProxy): @@ -127,171 +47,149 @@ def __init__( self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool ): super().__init__(cid) - self.client_fn = client_fn + + def _load_app() -> ClientApp: + return ClientApp(client_fn=client_fn) + + self.app_fn = _load_app self.actor_pool = actor_pool self.proxy_state = NodeState() - def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: - # The VCE is not exposed to TaskIns, it won't handle multilple workloads - # For the time being, fixing workload_id is a small compromise - # This will be one of the first points to address integrating VCE + DriverAPI - workload_id = 0 + def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: + """Sumbit a message to the ActorPool.""" + run_id = 
message.metadata.run_id # Register state - self.proxy_state.register_workloadstate(workload_id=workload_id) + self.proxy_state.register_context(run_id=run_id) # Retrieve state - state = self.proxy_state.retrieve_workloadstate(workload_id=workload_id) + state = self.proxy_state.retrieve_context(run_id=run_id) try: self.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (self.app_fn, message, self.cid, state), + ) + out_mssg, updated_context = self.actor_pool.get_client_result( + self.cid, timeout ) - res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) # Update state - self.proxy_state.update_workloadstate( - workload_id=workload_id, workload_state=updated_state - ) + self.proxy_state.update_context(run_id=run_id, context=updated_context) except Exception as ex: if self.actor_pool.num_actors == 0: # At this point we want to stop the simulation. 
- # since no more client workloads will be executed + # since no more client runs will be executed log(ERROR, "ActorPool is empty!!!") log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex - return res + return out_mssg + + def _wrap_recordset_in_message( + self, + recordset: RecordSet, + message_type: str, + timeout: Optional[float], + group_id: Optional[int], + ) -> Message: + """Wrap a RecordSet inside a Message.""" + return Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id=str(group_id) if group_id is not None else "", + src_node_id=0, + dst_node_id=int(self.cid), + reply_to_message="", + ttl=str(timeout) if timeout else "", + message_type=message_type, + partition_id=int(self.cid), + ), + ) def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] + self, + ins: common.GetPropertiesIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Return client's properties.""" + recordset = getpropertiesins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageTypeLegacy.GET_PROPERTIES, + timeout=timeout, + group_id=group_id, + ) - def get_properties(client: Client) -> common.GetPropertiesRes: - return maybe_call_get_properties( - client=client, - get_properties_ins=ins, - ) - - res = self._submit_job(get_properties, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.GetPropertiesRes, - res, - ) + return recordset_to_getpropertiesres(message_out.content) def get_parameters( - self, ins: common.GetParametersIns, timeout: Optional[float] + self, + ins: common.GetParametersIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" + recordset = getparametersins_to_recordset(ins) + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageTypeLegacy.GET_PARAMETERS, + timeout=timeout, + 
group_id=group_id, + ) - def get_parameters(client: Client) -> common.GetParametersRes: - return maybe_call_get_parameters( - client=client, - get_parameters_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(get_parameters, timeout) + return recordset_to_getparametersres(message_out.content, keep_input=False) - return cast( - common.GetParametersRes, - res, - ) - - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: + def fit( + self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.FitRes: """Train model parameters on the locally held dataset.""" + recordset = fitins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageType.TRAIN, + timeout=timeout, + group_id=group_id, + ) - def fit(client: Client) -> common.FitRes: - return maybe_call_fit( - client=client, - fit_ins=ins, - ) - - res = self._submit_job(fit, timeout) + message_out = self._submit_job(message, timeout) - return cast( - common.FitRes, - res, - ) + return recordset_to_fitres(message_out.content, keep_input=False) def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] + self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" + recordset = evaluateins_to_recordset( + ins, keep_input=True + ) # This must stay TRUE since ins are in-memory + message = self._wrap_recordset_in_message( + recordset, + message_type=MessageType.EVALUATE, + timeout=timeout, + group_id=group_id, + ) - def evaluate(client: Client) -> common.EvaluateRes: - return maybe_call_evaluate( - client=client, - evaluate_ins=ins, - ) + message_out = self._submit_job(message, timeout) - res = self._submit_job(evaluate, timeout) - - return cast( - common.EvaluateRes, - res, - ) + return 
recordset_to_evaluateres(message_out.content) def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] + self, + ins: common.ReconnectIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - - -@ray.remote -def launch_and_get_properties( - client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns -) -> common.GetPropertiesRes: - """Exectue get_properties remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - -@ray.remote -def launch_and_get_parameters( - client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns -) -> common.GetParametersRes: - """Exectue get_parameters remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - -@ray.remote -def launch_and_fit( - client_fn: ClientFn, cid: str, fit_ins: common.FitIns -) -> common.FitRes: - """Exectue fit remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - -@ray.remote -def launch_and_evaluate( - client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns -) -> common.EvaluateRes: - """Exectue evaluate remotely.""" - client: Client = _create_client(client_fn, cid) - return maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - -def _create_client(client_fn: ClientFn, cid: str) -> Client: - """Create a client instance.""" - # Materialize client - return client_fn(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index b87418b671d3..22c5425cd9fd 100644 --- 
a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,17 +17,29 @@ from math import pi from random import shuffle -from typing import List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type import ray from flwr.client import Client, NumPyClient -from flwr.client.workload_state import WorkloadState -from flwr.common import Code, GetPropertiesRes, Status +from flwr.client.client_app import ClientApp +from flwr.common import ( + Config, + ConfigsRecord, + Context, + Message, + MessageTypeLegacy, + Metadata, + RecordSet, + Scalar, +) +from flwr.common.recordset_compat import ( + getpropertiesins_to_recordset, + recordset_to_getpropertiesres, +) +from flwr.common.recordset_compat_test import _get_valid_getpropertiesins from flwr.simulation.ray_transport.ray_actor import ( - ClientRes, - DefaultActor, - JobFn, + ClientAppActor, VirtualClientEngineActor, VirtualClientEngineActorPool, ) @@ -40,32 +52,24 @@ class DummyClient(NumPyClient): def __init__(self, cid: str) -> None: self.cid = int(cid) + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = int(self.cid) * pi + + # store something in context + self.context.state.configs_records["result"] = ConfigsRecord( + {"result": str(result)} + ) + return {"result": result} + def get_dummy_client(cid: str) -> Client: """Return a DummyClient converted to Client type.""" return DummyClient(cid).to_client() -# A dummy workload -def job_fn(cid: str) -> JobFn: # pragma: no cover - """Construct a simple job with cid dependency.""" - - def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument - result = int(cid) * pi - - # store something in state - client.numpy_client.state.state["result"] = str(result) # type: ignore - - # now let's convert it to a GetPropertiesRes response - return GetPropertiesRes( - status=Status(Code(0), 
message="test"), properties={"result": result} - ) - - return cid_times_pi - - def prep( - actor_type: Type[VirtualClientEngineActor] = DefaultActor, + actor_type: Type[VirtualClientEngineActor] = ClientAppActor, ) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} @@ -100,51 +104,75 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ proxies, _ = prep() + + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) + # submit jobs one at a time for prox in proxies: - res = prox._submit_job( # pylint: disable=protected-access - job_fn=job_fn(prox.cid), timeout=None + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + MessageTypeLegacy.GET_PROPERTIES, + timeout=None, + group_id=0, + ) + message_out = prox._submit_job( # pylint: disable=protected-access + message=message, timeout=None ) - res = cast(GetPropertiesRes, res) + res = recordset_to_getpropertiesres(message_out.content) + assert int(prox.cid) * pi == res.properties["result"] ray.shutdown() -def test_cid_consistency_all_submit_first_workload_consistency() -> None: +def test_cid_consistency_all_submit_first_run_consistency() -> None: """Test that ClientProxies get the result of client job they submit. All jobs are submitted at the same time. Then fetched one at a time. This also tests - NodeState (at each Proxy) and WorkloadState basic functionality. + NodeState (at each Proxy) and RunState basic functionality. 
""" proxies, _ = prep() - workload_id = 0 + run_id = 0 + + getproperties_ins = _get_valid_getpropertiesins() + recordset = getpropertiesins_to_recordset(getproperties_ins) # submit all jobs (collect later) shuffle(proxies) for prox in proxies: # Register state - prox.proxy_state.register_workloadstate(workload_id=workload_id) + prox.proxy_state.register_context(run_id=run_id) # Retrieve state - state = prox.proxy_state.retrieve_workloadstate(workload_id=workload_id) + state = prox.proxy_state.retrieve_context(run_id=run_id) - job = job_fn(prox.cid) + message = prox._wrap_recordset_in_message( # pylint: disable=protected-access + recordset, + message_type=MessageTypeLegacy.GET_PROPERTIES, + timeout=None, + group_id=0, + ) prox.actor_pool.submit_client_job( - lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, state), + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (prox.app_fn, message, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) - prox.proxy_state.update_workloadstate(workload_id, workload_state=updated_state) - res = cast(GetPropertiesRes, res) + message_out, updated_context = prox.actor_pool.get_client_result( + prox.cid, timeout=None + ) + prox.proxy_state.update_context(run_id, context=updated_context) + res = recordset_to_getpropertiesres(message_out.content) + assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_workloadstate(workload_id).state["result"] + == prox.proxy_state.retrieve_context(run_id).state.configs_records[ + "result" + ]["result"] ) ray.shutdown() @@ -156,20 +184,39 @@ def test_cid_consistency_without_proxies() -> None: num_clients = len(proxies) cids = [str(cid) for cid in range(num_clients)] + getproperties_ins = _get_valid_getpropertiesins() + recordset = 
getpropertiesins_to_recordset(getproperties_ins) + + def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) + # submit all jobs (collect later) shuffle(cids) for cid in cids: - job = job_fn(cid) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id=str(0), + src_node_id=0, + dst_node_id=12345, + reply_to_message="", + ttl="", + message_type=MessageTypeLegacy.GET_PROPERTIES, + partition_id=int(cid), + ), + ) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, WorkloadState(state={})), + (_load_app, message, cid, Context(state=RecordSet())), ) # fetch results one at a time shuffle(cids) for cid in cids: - res, _ = pool.get_client_result(cid, timeout=None) - res = cast(GetPropertiesRes, res) + message_out, _ = pool.get_client_result(cid, timeout=None) + res = recordset_to_getpropertiesres(message_out.content) assert int(cid) * pi == res.properties["result"] ray.shutdown() diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index c8e6aa6cbe21..3861164998a4 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -15,9 +15,9 @@ """Utilities for Actors in the Virtual Client Engine.""" import traceback +import warnings from logging import ERROR -from flwr.client import Client from flwr.common.logger import log try: @@ -26,7 +26,7 @@ TF = None # Display Deprecation warning once -# warnings.filterwarnings("once", category=DeprecationWarning) +warnings.filterwarnings("once", category=DeprecationWarning) def enable_tf_gpu_growth() -> None: @@ -37,7 +37,7 @@ def enable_tf_gpu_growth() -> None: # the same GPU. # Luckily we can disable this behavior by enabling memory growth # on the GPU. In this way, VRAM allocated to the processes grows based - # on the needs for the workload. 
(this is for instance the default + # on the needs for the run. (this is for instance the default # behavior in PyTorch) # While this behavior is critical for Actors, you'll likely need it # as well in your main process (where the server runs and might evaluate @@ -59,25 +59,3 @@ def enable_tf_gpu_growth() -> None: log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex - - -def check_clientfn_returns_client(client: Client) -> Client: - """Warn once that clients returned in `clinet_fn` should be of type Client. - - This is here for backwards compatibility. If a ClientFn is provided returning - a different type of client (e.g. NumPyClient) we'll warn the user but convert - the client internally to `Client` by calling `.to_client()`. - """ - if not isinstance(client, Client): - # mssg = ( - # " Ensure your client is of type `Client`. Please convert it" - # " using the `.to_client()` method before returning it" - # " in the `client_fn` you pass to `start_simulation`." - # " We have applied this conversion on your behalf." - # " Not returning a `Client` might trigger an error in future" - # " versions of Flower." - # ) - - # warnings.warn(mssg, DeprecationWarning, stacklevel=2) - client = client.to_client() - return client diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py new file mode 100644 index 000000000000..56fce363726a --- /dev/null +++ b/src/py/flwr/simulation/run_simulation.py @@ -0,0 +1,441 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower Simulation.""" + +import argparse +import asyncio +import json +import logging +import threading +import traceback +from logging import DEBUG, ERROR, INFO, WARNING +from time import sleep +from typing import Dict, Optional + +import grpc + +from flwr.client import ClientApp +from flwr.common import EventType, event, log +from flwr.common.typing import ConfigsRecordValues +from flwr.server.driver.driver import Driver +from flwr.server.run_serverapp import run +from flwr.server.server_app import ServerApp +from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc +from flwr.server.superlink.fleet import vce +from flwr.server.superlink.state import StateFactory +from flwr.simulation.ray_transport.utils import ( + enable_tf_gpu_growth as enable_gpu_growth, +) + + +# Entry point from CLI +def run_simulation_from_cli() -> None: + """Run Simulation Engine from the CLI.""" + args = _parse_args_run_simulation().parse_args() + + # Load JSON config + backend_config_dict = json.loads(args.backend_config) + + _run_simulation( + server_app_attr=args.server_app, + client_app_attr=args.client_app, + num_supernodes=args.num_supernodes, + backend_name=args.backend, + backend_config=backend_config_dict, + app_dir=args.app_dir, + driver_api_address=args.driver_api_address, + enable_tf_gpu_growth=args.enable_tf_gpu_growth, + verbose_logging=args.verbose, + ) + + +# Entry point from Python session (script or notebook) +# pylint: disable=too-many-arguments +def run_simulation( + server_app: ServerApp, + client_app: ClientApp, + num_supernodes: int, + backend_name: str = "ray", + backend_config: Optional[Dict[str, ConfigsRecordValues]] = None, + enable_tf_gpu_growth: bool = False, + verbose_logging: bool = False, +) -> None: + r"""Run a Flower App using the Simulation Engine. 
+
+    Parameters
+    ----------
+    server_app : ServerApp
+        The `ServerApp` to be executed. It will send messages to different `ClientApp`
+        instances running on different (virtual) SuperNodes.
+
+    client_app : ClientApp
+        The `ClientApp` to be executed by each of the SuperNodes. It will receive
+        messages sent by the `ServerApp`.
+
+    num_supernodes : int
+        Number of nodes that run a ClientApp. They can be sampled by a
+        Driver in the ServerApp and receive a Message describing what the ClientApp
+        should perform.
+
+    backend_name : str (default: ray)
+        A simulation backend that runs `ClientApp`s.
+
+    backend_config : Optional[Dict[str, ConfigsRecordValues]]
+        A dictionary, e.g {"<keyA>": <value>, "<keyB>": <value>} to configure a
+        backend. Values supported in <value> are those included by
+        `flwr.common.typing.ConfigsRecordValues`.
+
+    enable_tf_gpu_growth : bool (default: False)
+        A boolean to indicate whether to enable GPU growth on the main thread. This is
+        desirable if you make use of a TensorFlow model on your `ServerApp` while
+        having your `ClientApp` running on the same GPU. Without enabling this, you
+        might encounter an out-of-memory error because TensorFlow, by default, allocates
+        all GPU memory. Read more about how `tf.config.experimental.set_memory_growth()`
+        works in the TensorFlow documentation: https://www.tensorflow.org/api/stable.
+
+    verbose_logging : bool (default: False)
+        When disabled, only INFO, WARNING and ERROR log messages will be shown. If
+        enabled, DEBUG-level logs will be displayed.
+    """
+    _run_simulation(
+        num_supernodes=num_supernodes,
+        client_app=client_app,
+        server_app=server_app,
+        backend_name=backend_name,
+        backend_config=backend_config,
+        enable_tf_gpu_growth=enable_tf_gpu_growth,
+        verbose_logging=verbose_logging,
+    )
+
+
+# pylint: disable=too-many-arguments
+def run_serverapp_th(
+    server_app_attr: Optional[str],
+    server_app: Optional[ServerApp],
+    driver: Driver,
+    app_dir: str,
+    f_stop: asyncio.Event,
+    enable_tf_gpu_growth: bool,
+    delay_launch: int = 3,
+) -> threading.Thread:
+    """Run ServerApp in a thread."""
+
+    def server_th_with_start_checks(  # type: ignore
+        tf_gpu_growth: bool, stop_event: asyncio.Event, **kwargs
+    ) -> None:
+        """Run ServerApp, after checking whether GPU memory growth has to be set.
+
+        Upon exception, trigger stop event for Simulation Engine.
+        """
+        try:
+            if tf_gpu_growth:
+                log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
+                enable_gpu_growth()
+
+            # Run ServerApp
+            run(**kwargs)
+        except Exception as ex:  # pylint: disable=broad-exception-caught
+            log(ERROR, "ServerApp thread raised an exception: %s", ex)
+            log(ERROR, traceback.format_exc())
+        finally:
+            log(DEBUG, "ServerApp finished running.")
+            # Upon completion, trigger stop event if one was passed
+            if stop_event is not None:
+                stop_event.set()
+                log(WARNING, "Triggered stop event for Simulation Engine.")
+
+    serverapp_th = threading.Thread(
+        target=server_th_with_start_checks,
+        args=(enable_tf_gpu_growth, f_stop),
+        kwargs={
+            "server_app_attr": server_app_attr,
+            "loaded_server_app": server_app,
+            "driver": driver,
+            "server_app_dir": app_dir,
+        },
+    )
+    sleep(delay_launch)
+    serverapp_th.start()
+    return serverapp_th
+
+
+# pylint: disable=too-many-locals
+def _main_loop(
+    num_supernodes: int,
+    backend_name: str,
+    backend_config_stream: str,
+    driver_api_address: str,
+    app_dir: str,
+    enable_tf_gpu_growth: bool,
+    client_app: Optional[ClientApp] = None,
+    client_app_attr: Optional[str] = None,
+    server_app: 
# pylint: disable=too-many-locals
def _main_loop(
    num_supernodes: int,
    backend_name: str,
    backend_config_stream: str,
    driver_api_address: str,
    app_dir: str,
    enable_tf_gpu_growth: bool,
    client_app: Optional[ClientApp] = None,
    client_app_attr: Optional[str] = None,
    server_app: Optional[ServerApp] = None,
    server_app_attr: Optional[str] = None,
) -> None:
    """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.

    Everything runs on the main thread or a separate one, depending on whether the
    main thread already contains a running Asyncio event loop. This is the case if
    running the Simulation Engine on a Jupyter/Colab notebook.
    """
    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    f_stop = asyncio.Event()
    serverapp_th = None
    # Pre-bind `driver` so cleanup in `finally` is safe even if Driver
    # construction below raises (previously this caused an UnboundLocalError
    # in `finally` that masked the original exception).
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=driver_api_address,
            root_certificates=None,
        )

        # Get and run ServerApp thread
        serverapp_th = run_serverapp_th(
            server_app_attr=server_app_attr,
            server_app=server_app,
            driver=driver,
            app_dir=app_dir,
            f_stop=f_stop,
            enable_tf_gpu_growth=enable_tf_gpu_growth,
        )

        # SuperLink with Simulation Engine
        event(EventType.RUN_SUPERLINK_ENTER)
        vce.start_vce(
            num_supernodes=num_supernodes,
            client_app_attr=client_app_attr,
            client_app=client_app,
            backend_name=backend_name,
            backend_config_json_stream=backend_config_stream,
            app_dir=app_dir,
            state_factory=state_factory,
            f_stop=f_stop,
        )

    except Exception as ex:
        log(ERROR, "An exception occurred !! %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError("An error was encountered. Ending simulation.") from ex

    finally:
        # Stop Driver API server
        driver_server.stop(grace=0)
        # Close the Driver only if it was successfully created
        if driver is not None:
            driver.close()
        # Trigger stop event so the Simulation Engine shuts down
        f_stop.set()

        event(EventType.RUN_SUPERLINK_LEAVE)
        if serverapp_th:
            serverapp_th.join()

    log(INFO, "Stopping Simulation Engine now.")
# pylint: disable=too-many-arguments,too-many-locals
def _run_simulation(
    num_supernodes: int,
    client_app: Optional[ClientApp] = None,
    server_app: Optional[ServerApp] = None,
    backend_name: str = "ray",
    backend_config: Optional[Dict[str, ConfigsRecordValues]] = None,
    client_app_attr: Optional[str] = None,
    server_app_attr: Optional[str] = None,
    app_dir: str = "",
    driver_api_address: str = "0.0.0.0:9091",
    enable_tf_gpu_growth: bool = False,
    verbose_logging: bool = False,
) -> None:
    r"""Launch the Simulation Engine.

    Parameters
    ----------
    num_supernodes : int
        Number of nodes that run a ClientApp. They can be sampled by a
        Driver in the ServerApp and receive a Message describing what the ClientApp
        should perform.

    client_app : Optional[ClientApp]
        The `ClientApp` to be executed by each of the `SuperNodes`. It will receive
        messages sent by the `ServerApp`.

    server_app : Optional[ServerApp]
        The `ServerApp` to be executed.

    backend_name : str (default: ray)
        A simulation backend that runs `ClientApp`s.

    backend_config : Optional[Dict[str, ConfigsRecordValues]]
        A dictionary to configure a backend. Values supported are those included by
        `flwr.common.typing.ConfigsRecordValues`.

    client_app_attr : str
        A path to a `ClientApp` module to be loaded: For example: `client:app` or
        `project.package.module:wrapper.app`.

    server_app_attr : str
        A path to a `ServerApp` module to be loaded: For example: `server:app` or
        `project.package.module:wrapper.app`.

    app_dir : str
        Add specified directory to the PYTHONPATH and load `ClientApp` from there.
        (Default: current working directory.)

    driver_api_address : str (default: "0.0.0.0:9091")
        Driver API (gRPC) server address (IPv4, IPv6, or a domain name).

    enable_tf_gpu_growth : bool (default: False)
        A boolean to indicate whether to enable GPU growth on the main thread. This is
        desirable if you make use of a TensorFlow model on your `ServerApp` while
        having your `ClientApp` running on the same GPU. Without enabling this, you
        might encounter an out-of-memory error because TensorFlow, by default,
        allocates all GPU memory. Read more about how
        `tf.config.experimental.set_memory_growth()` works in the TensorFlow
        documentation: https://www.tensorflow.org/api/stable.

    verbose_logging : bool (default: False)
        When disabled, only INFO, WARNING and ERROR log messages will be shown. If
        enabled, DEBUG-level logs will be displayed.
    """
    # Set logging level
    if not verbose_logging:
        logger = logging.getLogger("flwr")
        logger.setLevel(INFO)

    if backend_config is None:
        backend_config = {}

    if enable_tf_gpu_growth:
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config["tensorflow"] = True

    # Convert config to original JSON-stream format
    backend_config_stream = json.dumps(backend_config)

    simulation_engine_th = None
    args = (
        num_supernodes,
        backend_name,
        backend_config_stream,
        driver_api_address,
        app_dir,
        enable_tf_gpu_growth,
        client_app,
        client_app_attr,
        server_app,
        server_app_attr,
    )
    # Detect if there is an Asyncio event loop already running.
    # If yes, run everything on a separate thread. In environments
    # like Jupyter/Colab notebooks, there is an event loop present.
    run_in_thread = False
    try:
        _ = (
            asyncio.get_running_loop()
        )  # Raises RuntimeError if no event loop is present
        log(DEBUG, "Asyncio event loop already running.")

        run_in_thread = True

    except RuntimeError:
        log(DEBUG, "No asyncio event loop running")

    finally:
        if run_in_thread:
            log(DEBUG, "Starting Simulation Engine on a new thread.")
            simulation_engine_th = threading.Thread(target=_main_loop, args=args)
            simulation_engine_th.start()
            simulation_engine_th.join()
        else:
            log(DEBUG, "Starting Simulation Engine on the main thread.")
            _main_loop(*args)
This is desirable if you make " + "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` " + "running on the same GPU. Without enabling this, you might encounter an " + "out-of-memory error because TensorFlow by default allocates all GPU memory." + "Read more about how `tf.config.experimental.set_memory_growth()` works in " + "the TensorFlow documentation: https://www.tensorflow.org/api/stable.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="When unset, only INFO, WARNING and ERROR log messages will be shown. " + "If set, DEBUG-level logs will be displayed. ", + ) + parser.add_argument( + "--app-dir", + default="", + help="Add specified directory to the PYTHONPATH and load" + "ClientApp and ServerApp from there." + " Default: current working directory.", + ) + + return parser diff --git a/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py b/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py index 1213410fbc34..f6b922b27eab 100644 --- a/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py +++ b/src/py/flwr_example/tensorflow_fashion_mnist/fashion_mnist_test.py @@ -21,7 +21,7 @@ def test_shuffle() -> None: - """Test if shuffle is deterministic depending on the the provided seed.""" + """Test if shuffle is deterministic depending on the provided seed.""" # Prepare x_tt = np.arange(8) y_tt = np.arange(8) diff --git a/src/py/flwr_experimental/baseline/config/config.py b/src/py/flwr_experimental/baseline/config/config.py index 5170ea7f7e26..16c144bb6a2f 100644 --- a/src/py/flwr_experimental/baseline/config/config.py +++ b/src/py/flwr_experimental/baseline/config/config.py @@ -23,7 +23,7 @@ from flwr_experimental.ops.instance import Instance # We assume that devices which are older will have at most -# ~80% of the the Samsung Galaxy Note 5 compute performance. +# ~80% of the Samsung Galaxy Note 5 compute performance. 
SCORE_MISSING = int(226 * 0.80) DEVICE_DISTRIBUTION = [ diff --git a/src/py/flwr_experimental/baseline/tf_cifar/settings.py b/src/py/flwr_experimental/baseline/tf_cifar/settings.py index 829856af0ace..ed1a72cafac9 100644 --- a/src/py/flwr_experimental/baseline/tf_cifar/settings.py +++ b/src/py/flwr_experimental/baseline/tf_cifar/settings.py @@ -120,9 +120,9 @@ def configure_clients( cid=str(i), partition=i, # Indices 0 to 49 fast, 50 to 99 slow - delay_factor=delay_factor_fast - if i < int(num_clients / 2) - else delay_factor_slow, + delay_factor=( + delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow + ), # Shared iid_fraction=iid_fraction, num_clients=num_clients, diff --git a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py index b0de9841de30..72adc1f0be04 100644 --- a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py +++ b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py @@ -141,9 +141,9 @@ def configure_clients( cid=str(i), partition=i, # Indices 0 to 49 fast, 50 to 99 slow - delay_factor=delay_factor_fast - if i < int(num_clients / 2) - else delay_factor_slow, + delay_factor=( + delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow + ), # Shared iid_fraction=iid_fraction, num_clients=num_clients, diff --git a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py index 18c58d1333a5..5bfb7b1e42ad 100644 --- a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py +++ b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py @@ -144,9 +144,9 @@ def configure_clients( cid=str(i), partition=i, # Indices 0 to 49 fast, 50 to 99 slow - delay_factor=delay_factor_fast - if i < int(num_clients / 2) - else delay_factor_slow, + delay_factor=( + delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow + ), # Shared iid_fraction=iid_fraction, num_clients=num_clients, 
diff --git a/src/py/flwr_experimental/logserver/server.py b/src/py/flwr_experimental/logserver/server.py index dc729cebbf33..683b12b6db6c 100644 --- a/src/py/flwr_experimental/logserver/server.py +++ b/src/py/flwr_experimental/logserver/server.py @@ -69,9 +69,9 @@ def upload_file(local_filepath: str, s3_key: Optional[str]) -> None: Bucket=CONFIG["s3_bucket"], Key=s3_key, ExtraArgs={ - "ContentType": "application/pdf" - if s3_key.endswith(".pdf") - else "text/plain" + "ContentType": ( + "application/pdf" if s3_key.endswith(".pdf") else "text/plain" + ) }, ) # pylint: disable=broad-except diff --git a/src/py/flwr_experimental/ops/__init__.py b/src/py/flwr_experimental/ops/__init__.py index b56c757e0207..bad31028e68c 100644 --- a/src/py/flwr_experimental/ops/__init__.py +++ b/src/py/flwr_experimental/ops/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # ============================================================================== """Flower ops provides an opinionated way to provision necessary compute -infrastructure for running Flower workloads.""" +infrastructure for running Flower runs.""" diff --git a/src/py/flwr_experimental/ops/cluster.py b/src/py/flwr_experimental/ops/cluster.py index dfd0dd13727e..53a4e9617427 100644 --- a/src/py/flwr_experimental/ops/cluster.py +++ b/src/py/flwr_experimental/ops/cluster.py @@ -73,10 +73,10 @@ def ssh_connection( username, key_filename = ssh_credentials instance_ssh_port: int = cast(int, instance.ssh_port) - ignore_host_key_policy: Union[ - Type[MissingHostKeyPolicy], MissingHostKeyPolicy - ] = cast( - Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy + ignore_host_key_policy: Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy] = ( + cast( + Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy + ) ) client = SSHClient() @@ -139,7 +139,7 @@ def group_instances_by_specs(instances: List[Instance]) -> List[List[Instance]]: class Cluster: - """Compute enviroment 
independend compute cluster.""" + """Compute environment independend compute cluster.""" def __init__( self, diff --git a/src/py/flwr_tool/init_py_check.py b/src/py/flwr_tool/init_py_check.py index 8cdc2e0ab5be..67425139f991 100755 --- a/src/py/flwr_tool/init_py_check.py +++ b/src/py/flwr_tool/init_py_check.py @@ -36,7 +36,7 @@ def check_missing_init_files(absolute_path: str) -> None: if __name__ == "__main__": if len(sys.argv) == 0: - raise Exception( + raise Exception( # pylint: disable=W0719 "Please provide at least one directory path relative to your current working directory." ) for i, _ in enumerate(sys.argv): diff --git a/src/py/flwr_tool/protoc.py b/src/py/flwr_tool/protoc.py index 5d3ce942c1e0..b0b078c2eae4 100644 --- a/src/py/flwr_tool/protoc.py +++ b/src/py/flwr_tool/protoc.py @@ -51,7 +51,7 @@ def compile_all() -> None: exit_code = protoc.main(command) if exit_code != 0: - raise Exception(f"Error: {command} failed") + raise Exception(f"Error: {command} failed") # pylint: disable=W0719 if __name__ == "__main__": diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 57ca3ff423c2..2d48582eb441 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 5 + assert len(PROTO_FILES) == 7 diff --git a/src/py/flwr_tool/update_changelog.py b/src/py/flwr_tool/update_changelog.py new file mode 100644 index 000000000000..e3cffff7e36c --- /dev/null +++ b/src/py/flwr_tool/update_changelog.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module is used to update the changelog.""" + + +import re +from sys import argv + +from github import Github + +REPO_NAME = "adap/flower" +CHANGELOG_FILE = "doc/source/ref-changelog.md" +CHANGELOG_SECTION_HEADER = "### Changelog entry" + + +def _get_latest_tag(gh_api): + """Retrieve the latest tag from the GitHub repository.""" + repo = gh_api.get_repo(REPO_NAME) + tags = repo.get_tags() + return tags[0] if tags.totalCount > 0 else None + + +def _get_pull_requests_since_tag(gh_api, tag): + """Get a list of pull requests merged into the main branch since a given tag.""" + repo = gh_api.get_repo(REPO_NAME) + commits = {commit.sha for commit in repo.compare(tag.commit.sha, "main").commits} + prs = set() + for pr_info in repo.get_pulls( + state="closed", sort="created", direction="desc", base="main" + ): + if pr_info.merge_commit_sha in commits: + prs.add(pr_info) + if len(prs) == len(commits): + break + return prs + + +def _format_pr_reference(title, number, url): + """Format a pull request reference as a markdown list item.""" + return f"- **{title.replace('*', '')}** ([#{number}]({url}))" + + +def _extract_changelog_entry(pr_info): + """Extract the changelog entry from a pull request's body.""" + if not pr_info.body: + return None, "general" + + entry_match = re.search( + f"{CHANGELOG_SECTION_HEADER}(.+?)(?=##|$)", pr_info.body, re.DOTALL + ) + if not entry_match: + return None, None + + entry_text = entry_match.group(1).strip() + + # Remove markdown comments + 
def _extract_changelog_entry(pr_info):
    """Extract the changelog entry from a pull request's body.

    Returns a ``(entry_text, token)`` pair: the raw entry text (or None) and
    the category token detected in it (None when no entry section exists).
    """
    # A PR with no body cannot carry an explicit entry; treat it as "general"
    if not pr_info.body:
        return None, "general"

    entry_match = re.search(
        f"{CHANGELOG_SECTION_HEADER}(.+?)(?=##|$)", pr_info.body, re.DOTALL
    )
    if not entry_match:
        return None, None

    entry_text = entry_match.group(1).strip()

    # Remove markdown comments
    # NOTE(review): the regex below and the marker values in `token_markers`
    # appear to be empty strings — likely HTML-comment markers lost during
    # text extraction. With empty markers, `marker in entry_text` is always
    # True and the first key ("general") always wins. Confirm against the
    # original file before relying on this matching.
    entry_text = re.sub(r"", "", entry_text, flags=re.DOTALL).strip()

    token_markers = {
        "general": "",
        "skip": "",
        "baselines": "",
        "examples": "",
        "sdk": "",
        "simulations": "",
    }

    # Find the token based on the presence of its marker in entry_text
    token = next(
        (token for token, marker in token_markers.items() if marker in entry_text), None
    )

    return entry_text, token


def _update_changelog(prs):
    """Update the changelog file with entries from provided pull requests."""
    with open(CHANGELOG_FILE, "r+", encoding="utf-8") as file:
        content = file.read()
        unreleased_index = content.find("## Unreleased")

        if unreleased_index == -1:
            print("Unreleased header not found in the changelog.")
            return

        # Find the end of the Unreleased section
        next_header_index = content.find("##", unreleased_index + 1)
        next_header_index = (
            next_header_index if next_header_index != -1 else len(content)
        )

        for pr_info in prs:
            pr_entry_text, category = _extract_changelog_entry(pr_info)

            # Skip if PR should be skipped or already in changelog
            if category == "skip" or f"#{pr_info.number}]" in content:
                continue

            pr_reference = _format_pr_reference(
                pr_info.title, pr_info.number, pr_info.html_url
            )

            # Process based on category
            if category in ["general", "baselines", "examples", "sdk", "simulations"]:
                entry_title = _get_category_title(category)
                content = _update_entry(
                    content,
                    entry_title,
                    pr_info,
                    unreleased_index,
                    next_header_index,
                )

            elif pr_entry_text:
                content = _insert_new_entry(
                    content, pr_info, pr_reference, pr_entry_text, unreleased_index
                )

            else:
                content = _insert_entry_no_desc(content, pr_reference, unreleased_index)

            # Recompute the section boundary: the insert above shifted offsets
            next_header_index = content.find("##", unreleased_index + 1)
            next_header_index = (
                next_header_index if next_header_index != -1 else len(content)
            )

        # Finalize content update: rewrite the file in place and truncate any
        # leftover bytes from the previous (longer) version
        file.seek(0)
        file.write(content)
        file.truncate()

    print("Changelog updated.")
updated.") + + +def _get_category_title(category): + """Get the title of a changelog section based on its category.""" + headers = { + "general": "General improvements", + "baselines": "General updates to Flower Baselines", + "examples": "General updates to Flower Examples", + "sdk": "General updates to Flower SDKs", + "simulations": "General updates to Flower Simulations", + } + return headers.get(category, "") + + +def _update_entry( + content, category_title, pr_info, unreleased_index, next_header_index +): + """Update a specific section in the changelog content.""" + if ( + section_index := content.find( + category_title, unreleased_index, next_header_index + ) + ) != -1: + newline_index = content.find("\n", section_index) + closing_parenthesis_index = content.rfind(")", unreleased_index, newline_index) + updated_entry = f", [{pr_info.number}]({pr_info.html_url})" + content = ( + content[:closing_parenthesis_index] + + updated_entry + + content[closing_parenthesis_index:] + ) + else: + new_section = ( + f"\n- **{category_title}** ([#{pr_info.number}]({pr_info.html_url}))\n" + ) + insert_index = content.find("\n", unreleased_index) + 1 + content = content[:insert_index] + new_section + content[insert_index:] + return content + + +def _insert_new_entry(content, pr_info, pr_reference, pr_entry_text, unreleased_index): + """Insert a new entry into the changelog.""" + if (existing_entry_start := content.find(pr_entry_text)) != -1: + pr_ref_end = content.rfind("\n", 0, existing_entry_start) + updated_entry = ( + f"{content[pr_ref_end]}\n, [{pr_info.number}]({pr_info.html_url})" + ) + content = content[:pr_ref_end] + updated_entry + content[existing_entry_start:] + else: + insert_index = content.find("\n", unreleased_index) + 1 + + # Split the pr_entry_text into paragraphs + paragraphs = pr_entry_text.split("\n") + + # Indent each paragraph + indented_paragraphs = [ + " " + paragraph if paragraph else paragraph for paragraph in paragraphs + ] + + # Join the paragraphs 
back together, ensuring each is separated by a newline + indented_pr_entry_text = "\n".join(indented_paragraphs) + + content = ( + content[:insert_index] + + "\n" + + pr_reference + + "\n\n" + + indented_pr_entry_text + + "\n" + + content[insert_index:] + ) + return content + + +def _insert_entry_no_desc(content, pr_reference, unreleased_index): + """Insert a changelog entry for a pull request with no specific description.""" + insert_index = content.find("\n", unreleased_index) + 1 + content = ( + content[:insert_index] + "\n" + pr_reference + "\n" + content[insert_index:] + ) + return content + + +def main(): + """Update changelog using the descriptions of PRs since the latest tag.""" + # Initialize GitHub Client with provided token (as argument) + gh_api = Github(argv[1]) + latest_tag = _get_latest_tag(gh_api) + if not latest_tag: + print("No tags found in the repository.") + return + + prs = _get_pull_requests_since_tag(gh_api, latest_tag) + _update_changelog(prs) + + +if __name__ == "__main__": + main()