diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 61c8d14039db..03849995735d 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -31,7 +31,7 @@ RUN apt-get install -y curl wget gnupg python3 python-is-python3 python3-pip git RUN python -m pip install \ pip==23.3.1 \ setuptools==68.2.2 \ - poetry==1.5.1 + poetry==1.7.1 USER $USERNAME ENV PATH="/home/$USERNAME/.local/bin:${PATH}" diff --git a/.github/ISSUE_TEMPLATE/baseline_request.yml b/.github/ISSUE_TEMPLATE/baseline_request.yml index 49ae922a94ad..8df8eec22a21 100644 --- a/.github/ISSUE_TEMPLATE/baseline_request.yml +++ b/.github/ISSUE_TEMPLATE/baseline_request.yml @@ -38,18 +38,18 @@ body: attributes: label: For first time contributors value: | - - [ ] Read the [`first contribution` doc](https://flower.dev/docs/first-time-contributors.html) + - [ ] Read the [`first contribution` doc](https://flower.ai/docs/first-time-contributors.html) - [ ] Complete the Flower tutorial - [ ] Read the Flower Baselines docs to get an overview: - - [ ] [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) - - [ ] [How to contribute a Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) + - [ ] [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) + - [ ] [How to contribute a Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) - type: checkboxes attributes: label: Prepare - understand the scope options: - label: Read the paper linked above - label: Decide which experiments you'd like to reproduce. The more the better! - - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). + - label: Follow the steps outlined in [Add a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html#add-a-new-flower-baseline). - label: You can use as reference [other baselines](https://github.com/adap/flower/tree/main/baselines) that the community merged following those steps. - type: checkboxes attributes: diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index d0ab46d60b57..9e0ed23d4dbb 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,14 +1,14 @@ blank_issues_enabled: false contact_links: - name: Slack Channel - url: https://flower.dev/join-slack + url: https://flower.ai/join-slack about: Connect with other Flower users and contributors and discuss with them or ask them questions. - name: Discussion url: https://github.com/adap/flower/discussions - about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. + about: Ask about new features or general questions. Please use the discussion area in most of the cases instead of the issues. - name: Flower Issues url: https://github.com/adap/flower/issues about: Contribute new features/enhancements, report bugs, or improve the documentation. - name: Flower Mail - url: https://flower.dev/ - about: If your project needs professional support please contact the Flower team (hello@flower.dev). + url: https://flower.ai/ + about: If your project needs professional support please contact the Flower team (hello@flower.ai). 
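The devcontainer above installs exact versions of its tooling (pip 23.3.1, setuptools 68.2.2, and now Poetry 1.7.1). Below is a minimal sketch of how those pins could be sanity-checked from inside the container; the script is not part of this diff, and it assumes all three packages were installed into the same Python environment and expose metadata via `importlib.metadata`:

```python
# Hypothetical check, not part of this diff: confirm the devcontainer's
# pinned tool versions. Assumes pip, setuptools, and poetry were installed
# into the same Python environment, as in the Dockerfile above.
from importlib.metadata import PackageNotFoundError, version

PINS = {"pip": "23.3.1", "setuptools": "68.2.2", "poetry": "1.7.1"}

for package, expected in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    marker = "OK" if installed == expected else f"expected {expected}"
    print(f"{package}=={installed} ({marker})")
```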
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b9d4a0a23e23..0077bbab0909 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -36,10 +36,10 @@ Example: The variable `rnd` was renamed to `server_round` to improve readability - [ ] Implement proposed change - [ ] Write tests -- [ ] Update [documentation](https://flower.dev/docs/writing-documentation.html) +- [ ] Update [documentation](https://flower.ai/docs/writing-documentation.html) - [ ] Update the changelog entry below - [ ] Make CI checks pass -- [ ] Ping maintainers on [Slack](https://flower.dev/join-slack/) (channel `#contributions`) +- [ ] Ping maintainers on [Slack](https://flower.ai/join-slack/) (channel `#contributions`) diff --git a/.github/actions/bootstrap/action.yml b/.github/actions/bootstrap/action.yml index 5b0c10f53f5e..7b1716bbc954 100644 --- a/.github/actions/bootstrap/action.yml +++ b/.github/actions/bootstrap/action.yml @@ -12,7 +12,7 @@ inputs: default: 68.2.2 poetry-version: description: "Version of poetry to be installed using pip" - default: 1.5.1 + default: 1.7.1 outputs: python-version: description: "Version range or exact version of Python or PyPy" diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 4a1289d9175a..99ec0671db66 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -114,7 +114,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1 + uses: actions/download-artifact@eaceaf801fd36c7dee90939fad912460b18a1ffe # v4.1.2 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index bfb26053836d..c4485fe72d10 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -1,14 +1,5 @@ name: Baselines -# The aim of this workflow is to test only the changed (or added) baseline. -# Here is the rough idea of how it works (more details are presented later in the comments): -# 1. Checks for the changes between the current branch and the main - in case of PR - -# or between the HEAD and HEAD~1 (main last commit and the previous one) - in case of -# a push to main. -# 2. Fails the test if there are changes to more than one baseline. Passes the test -# (skips the rests) if there are no changes to any baselines. Follows the test if only -# one baseline is added or modified. -# 3. Sets up the env specified for the baseline. -# 4. Runs the tests. 
+ on: push: branches: - main @@ -24,112 +15,75 @@ concurrency: env: FLWR_TELEMETRY_ENABLED: 0 -defaults: - run: - working-directory: baselines - jobs: - test_baselines: - name: Test + changes: runs-on: ubuntu-22.04 + permissions: + pull-requests: read + outputs: + baselines: ${{ steps.filter.outputs.changes }} steps: - uses: actions/checkout@v4 - # The depth two of the checkout is needed in case of merging to the main - # because we compare the HEAD (current version) with HEAD~1 (version before - # the PR was merged) - with: - fetch-depth: 2 - - name: Fetch main branch - run: | - # The main branch is needed in case of the PR to make a comparison (by - # default the workflow takes as little information as possible - it does not - # have the history - if [ ${{ github.event_name }} == "pull_request" ] - then - git fetch origin main:main - fi - - name: Find changed/new baselines - id: find_changed_baselines_dirs + + - shell: bash run: | - if [ ${{ github.event_name }} == "push" ] - then - # Push event triggered when merging to main - change_references="HEAD..HEAD~1" - else - # Pull request event triggered for any commit to a pull request - change_references="main..HEAD" - fi - dirs=$(git diff --dirstat=files,0 ${change_references} . | awk '{print $2}' | grep -E '^baselines/[^/]*/$' | \ - grep -v \ - -e '^baselines/dev' \ - -e '^baselines/baseline_template' \ - -e '^baselines/flwr_baselines' \ - -e '^baselines/doc' \ - | sed 's/^baselines\///') - # git diff --dirstat=files,0 ${change_references} . - checks the differences - # and a file is counted as changed if more than 0 lines were changed - # it returns the results in the format x.y% path/to/dir/ - # awk '{print $2}' - takes only the directories (skips the percentages) - # grep -E '^baselines/[^/]*/$' - takes only the paths that start with - # baseline (and have at least one subdirectory) - # grep -v -e ... - excludes the `baseline_template`, `dev`, `flwr_baselines` - # sed 's/^baselines\///' - narrows down the path to baseline/ - echo "Detected changed directories: ${dirs}" - # Save changed dirs to output of this step - EOF=$(dd if=/dev/urandom bs=15 count=1 status=none | base64) - echo "dirs<<EOF" >> "$GITHUB_OUTPUT" - for dir in $dirs - do - echo "$dir" >> "$GITHUB_OUTPUT" - done - echo "EOF" >> "$GITHUB_OUTPUT" - - name: Validate changed/new baselines - id: validate_changed_baselines_dirs + # create a list of all directories in baselines + { + echo 'FILTER_PATHS<> "$GITHUB_ENV" + + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: ${{ env.FILTER_PATHS }} + + - if: ${{ github.event.pull_request.head.repo.fork }} run: | - dirs="${{ steps.find_changed_baselines_dirs.outputs.dirs }}" - dirs_array=() - if [[ -n $dirs ]]; then - while IFS= read -r line; do - dirs_array+=("$line") - done <<< "$dirs" - fi - length=${#dirs_array[@]} - echo "The number of changed baselines is $length" - - if [ $length -gt 1 ]; then - echo "The changes should only apply to a single baseline" - exit 1 + CHANGES=$(echo "${{ toJson(steps.filter.outputs.changes) }}" | jq '. | length') + if [ "$CHANGES" -gt 1 ]; then + echo "::error ::The changes should only apply to a single baseline." + exit 1 fi - - if [ $length -eq 0 ]; then - echo "The baselines were not changed - skipping the remaining steps."
- echo "baseline_changed=false" >> "$GITHUB_OUTPUT" - exit 0 - fi - - echo "changed_dir=${dirs[0]}" >> "$GITHUB_OUTPUT" - echo "baseline_changed=true" >> "$GITHUB_OUTPUT" + + test: + runs-on: ubuntu-22.04 + needs: changes + if: ${{ needs.changes.outputs.baselines != '' && toJson(fromJson(needs.changes.outputs.baselines)) != '[]' }} + strategy: + matrix: + baseline: ${{ fromJSON(needs.changes.outputs.baselines) }} + steps: + - uses: actions/checkout@v4 + - name: Bootstrap - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: python-version: '3.10' + - name: Install dependencies - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - changed_dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - cd "${changed_dir}" - python -m poetry install - - name: Test - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline.sh $dir - - name: Test Structure - if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' - run: | - dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" - echo "Testing ${dir}" - ./dev/test-baseline-structure.sh $dir + working-directory: baselines/${{ matrix.baseline }} + run: python -m poetry install + + - name: Testing ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline.sh ${{ matrix.baseline }} + - name: Test Structure of ${{ matrix.baseline }} + working-directory: baselines + run: ./dev/test-baseline-structure.sh ${{ matrix.baseline }} diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index a1d918aa5b29..7efc879bcee2 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -19,13 +19,13 @@ jobs: uses: ./.github/actions/bootstrap - name: Cache restore SDK build - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache restore example build - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: examples/quickstart-cpp/build/ key: ${{ runner.os }}-example-build @@ -82,14 +82,14 @@ jobs: fi - name: Cache save SDK build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: build/ key: ${{ runner.os }}-sdk-build - name: Cache save example build - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 if: github.ref_name == 'main' with: path: examples/quickstart-cpp/build/ diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 94f3495a20ef..a4c769fdf850 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -43,8 +43,9 @@ jobs: AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets. 
AWS_SECRET_ACCESS_KEY }} + DOCS_BUCKET: flower.ai run: | - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://flower.dev/docs/framework - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://flower.dev/docs/baselines - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://flower.dev/docs/examples - aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://flower.dev/docs/datasets + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/framework + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/baselines + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/examples + aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/datasets diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 065e79fff9ab..db9f65a4f4f3 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -14,6 +14,7 @@ concurrency: env: FLWR_TELEMETRY_ENABLED: 0 + ARTIFACT_BUCKET: artifact.flower.ai jobs: wheel: @@ -43,7 +44,7 @@ jobs: echo "SHORT_SHA=$sha_short" >> "$GITHUB_OUTPUT" [ -z "${{ github.head_ref }}" ] && dir="${{ github.ref_name }}" || dir="pr/${{ github.head_ref }}" echo "DIR=$dir" >> "$GITHUB_OUTPUT" - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://artifact.flower.dev/py/$dir/$sha_short --recursive + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://${{ env.ARTIFACT_BUCKET }}/py/$dir/$sha_short --recursive outputs: whl_path: ${{ steps.upload.outputs.WHL_PATH }} short_sha: ${{ steps.upload.outputs.SHORT_SHA }} @@ -123,7 +124,7 @@ jobs: - name: Install Flower wheel from artifact store if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Download dataset if: ${{ matrix.dataset }} run: python -c "${{ matrix.dataset }}" @@ -164,9 +165,9 @@ jobs: - name: Install Flower wheel from artifact store if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | - python -m pip install https://artifact.flower.dev/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - name: Cache Datasets - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: "~/.keras" key: keras-datasets diff --git a/.github/workflows/framework-draft-release.yml b/.github/workflows/framework-draft-release.yml index 959d17249765..91a89953cf96 100644 --- a/.github/workflows/framework-draft-release.yml +++ 
b/.github/workflows/framework-draft-release.yml @@ -5,6 +5,9 @@ on: tags: - "v*.*.*" +env: + ARTIFACT_BUCKET: artifact.flower.ai + jobs: publish: if: ${{ github.repository == 'adap/flower' }} @@ -26,16 +29,16 @@ jobs: run: | tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" - + wheel_name="flwr-${tag_name}-py3-none-any.whl" echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" - + tar_name="flwr-${tag_name}.tar.gz" echo "TAR_NAME=$tar_name" >> "$GITHUB_ENV" - wheel_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${tar_name}" - + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/main/${GITHUB_SHA::7}/${tar_name}" + curl $wheel_url --output $wheel_name curl $tar_url --output $tar_name - name: Upload wheel @@ -44,14 +47,14 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} run: | - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} - + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://${{ env.ARTIFACT_BUCKET }}/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} + - name: Generate body run: | ./dev/get-latest-changelog.sh > body.md cat body.md - + - name: Release uses: softprops/action-gh-release@v1 with: diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index f052d3a4a928..04b68fd38af9 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -3,11 +3,14 @@ name: Publish `flwr` release on PyPI on: release: types: [released] - + concurrency: group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }} cancel-in-progress: true - + +env: + ARTIFACT_BUCKET: artifact.flower.ai + jobs: publish: if: ${{ github.repository == 'adap/flower' }} @@ -28,11 +31,11 @@ jobs: run: | TAG_NAME=$(echo "${GITHUB_REF_NAME}" | cut -c2-) - wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" + wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" tar_name="flwr-${TAG_NAME}.tar.gz" - wheel_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${tar_name}" + wheel_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${wheel_name}" + tar_url="https://${{ env.ARTIFACT_BUCKET }}/py/release/v${TAG_NAME}/${tar_name}" mkdir -p dist diff --git a/CHANGELOG.md b/CHANGELOG.md index 264548c87669..1f01c0e9717d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Changelog -Flower changes are tracked as part of the documentation: [Flower Changelog](https://flower.dev/docs/changelog.html). +Flower changes are tracked as part of the documentation: [Flower Changelog](https://flower.ai/docs/changelog.html). 
The changelog source can be edited here: `doc/source/changelog.rst` diff --git a/README.md b/README.md index e4433e517b88..b5c8bfdaa4b9 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ # Flower: A Friendly Federated Learning Framework

- - Flower Website + + Flower Website

- Website | - Blog | - Docs | - Conference | - Slack + Website | + Blog | + Docs | + Conference | + Slack

@@ -18,7 +18,7 @@ [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/adap/flower/blob/main/CONTRIBUTING.md) ![Build](https://github.com/adap/flower/actions/workflows/framework.yml/badge.svg) [![Downloads](https://static.pepy.tech/badge/flwr)](https://pepy.tech/project/flwr) -[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.dev/join-slack) +[![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.ai/join-slack) Flower (`flwr`) is a framework for building federated learning systems. The design of Flower is based on a few guiding principles: @@ -39,7 +39,7 @@ design of Flower is based on a few guiding principles: - **Understandable**: Flower is written with maintainability in mind. The community is encouraged to both read and contribute to the codebase. -Meet the Flower community on [flower.dev](https://flower.dev)! +Meet the Flower community on [flower.ai](https://flower.ai)! ## Federated Learning Tutorial @@ -73,19 +73,19 @@ Stay tuned, more tutorials are coming soon. Topics include **Privacy and Securit ## Documentation -[Flower Docs](https://flower.dev/docs): +[Flower Docs](https://flower.ai/docs): -- [Installation](https://flower.dev/docs/framework/how-to-install-flower.html) -- [Quickstart (TensorFlow)](https://flower.dev/docs/framework/tutorial-quickstart-tensorflow.html) -- [Quickstart (PyTorch)](https://flower.dev/docs/framework/tutorial-quickstart-pytorch.html) -- [Quickstart (Hugging Face)](https://flower.dev/docs/framework/tutorial-quickstart-huggingface.html) -- [Quickstart (PyTorch Lightning)](https://flower.dev/docs/framework/tutorial-quickstart-pytorch-lightning.html) -- [Quickstart (Pandas)](https://flower.dev/docs/framework/tutorial-quickstart-pandas.html) -- [Quickstart (fastai)](https://flower.dev/docs/framework/tutorial-quickstart-fastai.html) -- [Quickstart (JAX)](https://flower.dev/docs/framework/tutorial-quickstart-jax.html) -- [Quickstart (scikit-learn)](https://flower.dev/docs/framework/tutorial-quickstart-scikitlearn.html) -- [Quickstart (Android [TFLite])](https://flower.dev/docs/framework/tutorial-quickstart-android.html) -- [Quickstart (iOS [CoreML])](https://flower.dev/docs/framework/tutorial-quickstart-ios.html) +- [Installation](https://flower.ai/docs/framework/how-to-install-flower.html) +- [Quickstart (TensorFlow)](https://flower.ai/docs/framework/tutorial-quickstart-tensorflow.html) +- [Quickstart (PyTorch)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch.html) +- [Quickstart (Hugging Face)](https://flower.ai/docs/framework/tutorial-quickstart-huggingface.html) +- [Quickstart (PyTorch Lightning)](https://flower.ai/docs/framework/tutorial-quickstart-pytorch-lightning.html) +- [Quickstart (Pandas)](https://flower.ai/docs/framework/tutorial-quickstart-pandas.html) +- [Quickstart (fastai)](https://flower.ai/docs/framework/tutorial-quickstart-fastai.html) +- [Quickstart (JAX)](https://flower.ai/docs/framework/tutorial-quickstart-jax.html) +- [Quickstart (scikit-learn)](https://flower.ai/docs/framework/tutorial-quickstart-scikitlearn.html) +- [Quickstart (Android [TFLite])](https://flower.ai/docs/framework/tutorial-quickstart-android.html) +- [Quickstart (iOS [CoreML])](https://flower.ai/docs/framework/tutorial-quickstart-ios.html) ## Flower Baselines @@ -101,6 +101,7 @@ Flower Baselines is a collection of community-contributed projects that reproduc - [FedNova](https://github.com/adap/flower/tree/main/baselines/fednova) - [HeteroFL](https://github.com/adap/flower/tree/main/baselines/heterofl) - [FedAvgM](https://github.com/adap/flower/tree/main/baselines/fedavgm) +- [FedStar](https://github.com/adap/flower/tree/main/baselines/fedstar) - [FedWav2vec2](https://github.com/adap/flower/tree/main/baselines/fedwav2vec2) - [FjORD](https://github.com/adap/flower/tree/main/baselines/fjord) - [MOON](https://github.com/adap/flower/tree/main/baselines/moon) @@ -112,9 +113,9 @@ Flower Baselines is a collection of community-contributed projects that reproduc - [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist) - [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization) -Please refer to the [Flower Baselines Documentation](https://flower.dev/docs/baselines/) for a detailed categorization of baselines and for additional info including: -* [How to use Flower Baselines](https://flower.dev/docs/baselines/how-to-use-baselines.html) -* [How to contribute a new Flower Baseline](https://flower.dev/docs/baselines/how-to-contribute-baselines.html) +Please refer to the [Flower Baselines Documentation](https://flower.ai/docs/baselines/) for a detailed categorization of baselines and for additional info including: +* [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html) +* [How to contribute a new Flower Baseline](https://flower.ai/docs/baselines/how-to-contribute-baselines.html) ## Flower Usage Examples @@ -148,10 +149,11 @@ Other [examples](https://github.com/adap/flower/tree/main/examples): - Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)) - [Comprehensive Flower+XGBoost](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) - [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose) +- [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplan-meier-fitter) ## Community -Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.dev/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome. +Flower is built by a wonderful community of researchers and engineers. [Join Slack](https://flower.ai/join-slack) to meet them, [contributions](#contributing-to-flower) are welcome.
diff --git a/baselines/README.md b/baselines/README.md index a18c0553b2b4..3a84df02d8de 100644 --- a/baselines/README.md +++ b/baselines/README.md @@ -1,7 +1,7 @@ # Flower Baselines -> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). +> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). > The documentation below has been updated to reflect the new way of using Flower baselines. @@ -23,7 +23,7 @@ Please note that some baselines might include additional files (e.g. a `requirem ## Running the baselines -Each baseline is self-contained in its own directory. Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't setup `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.dev/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so. +Each baseline is self-contained in its own directory. Furthermore, each baseline defines its own Python environment using [Poetry](https://python-poetry.org/docs/) via a `pyproject.toml` file and [`pyenv`](https://github.com/pyenv/pyenv). If you haven't set up `Poetry` and `pyenv` already on your machine, please take a look at the [Documentation](https://flower.ai/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) for a guide on how to do so. Assuming `pyenv` and `Poetry` are already installed on your system. Running a baseline can be done by: @@ -54,7 +54,7 @@ The steps to follow are: ```bash # This will create a new directory with the same structure as `baseline_template`. ./dev/create-baseline.sh - ``` + ``` 3. Then, go inside your baseline directory and continue with the steps detailed in `EXTENDED_README.md` and `README.md`. 4. Once your code is ready and you have checked that following the instructions in your `README.md` the Python environment can be created correctly and that running the code following your instructions can reproduce the experiments in the paper, you just need to create a Pull Request (PR). Then, the process to merge your baseline into the Flower repo will begin!
diff --git a/baselines/baseline_template/pyproject.toml b/baselines/baseline_template/pyproject.toml index da7516437f09..31f1ee7bfe6d 100644 --- a/baselines/baseline_template/pyproject.toml +++ b/baselines/baseline_template/pyproject.toml @@ -7,11 +7,11 @@ name = "<baseline-name>" # <----- Ensure it matches the name of your baseline di version = "1.0.0" description = "Flower Baselines" license = "Apache-2.0" -authors = ["The Flower Authors <hello@flower.dev>"] +authors = ["The Flower Authors <hello@flower.ai>"] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -42,9 +42,9 @@ flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -80,10 +80,10 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [tool.pylint.typecheck] -generated-members="numpy.*, torch.*, tensorflow.*" +generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] module = [
diff --git a/baselines/dasha/pyproject.toml b/baselines/dasha/pyproject.toml index 3ef24e4b985a..f03ad06e26e4 100644 --- a/baselines/dasha/pyproject.toml +++ b/baselines/dasha/pyproject.toml @@ -9,9 +9,9 @@ description = "DASHA: Distributed nonconvex optimization with communication comp license = "Apache-2.0" authors = ["Alexander Tyurin "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -49,9 +49,9 @@ torchvision = [{ url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.
{ url = "https://download.pytorch.org/whl/cpu/torchvision-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", markers="sys_platform == 'darwin'"}] [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -60,6 +60,7 @@ pytest-watch = "==4.2.0" types-requests = "==2.27.7" py-spy = "==0.3.14" ruff = "==0.0.272" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -88,8 +89,8 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias,no-self-use,too-many-locals,too-many-instance-attributes" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" -generated-members="numpy.*, torch.*" +signature-mutators = "hydra.main.main" +generated-members = "numpy.*, torch.*" [[tool.mypy.overrides]] module = [ diff --git a/baselines/depthfl/depthfl/models.py b/baselines/depthfl/depthfl/models.py index df3eebf9f9ce..5ce5eedb360e 100644 --- a/baselines/depthfl/depthfl/models.py +++ b/baselines/depthfl/depthfl/models.py @@ -1,6 +1,5 @@ """ResNet18 model architecutre, training, and testing functions for CIFAR100.""" - from typing import List, Tuple import torch diff --git a/baselines/depthfl/pyproject.toml b/baselines/depthfl/pyproject.toml index 2f928c2d3553..e59d16fef88e 100644 --- a/baselines/depthfl/pyproject.toml +++ b/baselines/depthfl/pyproject.toml @@ -9,9 +9,9 @@ description = "DepthFL: Depthwise Federated Learning for Heterogeneous Clients" license = "Apache-2.0" authors = ["Minjae Kim "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -37,18 +37,17 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10.0, <3.11.0" +python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this matplotlib = "3.7.1" -torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} -torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} - +torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } +torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl" } [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -56,6 +55,7 @@ pytest = "==6.2.4" pytest-watch = "==4.2.0" ruff = "==0.0.272" types-requests = "==2.27.7" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -84,10 +84,10 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [tool.pylint.typecheck] -generated-members="numpy.*, torch.*, tensorflow.*" +generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] module = [ diff --git 
a/baselines/doc/source/_templates/base.html b/baselines/doc/source/_templates/base.html index 1cee99053fbe..cc171be9e6b0 100644 --- a/baselines/doc/source/_templates/base.html +++ b/baselines/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/baselines/doc/source/conf.py b/baselines/doc/source/conf.py index dabd421c61cf..dad8650cddaa 100644 --- a/baselines/doc/source/conf.py +++ b/baselines/doc/source/conf.py @@ -85,7 +85,7 @@ html_title = f"Flower Baselines {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/baselines/" +html_baseurl = "https://flower.ai/docs/baselines/" html_theme_options = { # diff --git a/baselines/doc/source/index.rst b/baselines/doc/source/index.rst index 335cfacef1ab..3a19e74b891e 100644 --- a/baselines/doc/source/index.rst +++ b/baselines/doc/source/index.rst @@ -1,7 +1,7 @@ Flower Baselines Documentation ============================== -Welcome to Flower Baselines' documentation. `Flower `_ is a friendly federated learning framework. +Welcome to Flower Baselines' documentation. `Flower `_ is a friendly federated learning framework. Join the Flower Community @@ -9,7 +9,7 @@ Join the Flower Community The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts. -.. button-link:: https://flower.dev/join-slack +.. button-link:: https://flower.ai/join-slack :color: primary :shadow: diff --git a/baselines/fedavgm/pyproject.toml b/baselines/fedavgm/pyproject.toml index 298deafd8932..d222baa65b0e 100644 --- a/baselines/fedavgm/pyproject.toml +++ b/baselines/fedavgm/pyproject.toml @@ -3,15 +3,15 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.masonry.api" [tool.poetry] -name = "fedavgm" +name = "fedavgm" version = "1.0.0" description = "FedAvgM: Measuring the effects of non-identical data distribution for federated visual classification" license = "Apache-2.0" authors = ["Gustavo Bertoli"] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -38,18 +38,17 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9, <3.12.0" # changed! 
original baseline template uses >= 3.8.15 -flwr = "1.5.0" -ray = "2.6.3" +flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" # don't change this cython = "^3.0.0" -tensorflow = "2.10" +tensorflow = "2.11.1" numpy = "1.25.2" matplotlib = "^3.7.2" [tool.poetry.dev-dependencies] -isort = "==5.11.5" -black = "==23.1.0" -docformatter = "==1.5.1" +isort = "==5.13.2" +black = "==24.2.0" +docformatter = "==1.7.5" mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" @@ -57,6 +56,7 @@ pytest = "==6.2.4" pytest-watch = "==4.2.0" ruff = "==0.0.272" types-requests = "==2.27.7" +virtualenv = "==20.21.0" [tool.isort] line_length = 88 @@ -85,7 +85,7 @@ plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" good-names = "i,j,k,_,x,y,X,Y" -signature-mutators="hydra.main.main" +signature-mutators = "hydra.main.main" [[tool.mypy.overrides]] module = [
diff --git a/baselines/fedbn/README.md b/baselines/fedbn/README.md index 4b271bd49851..d50c6f5bb605 100644 --- a/baselines/fedbn/README.md +++ b/baselines/fedbn/README.md @@ -34,37 +34,37 @@ dataset: [MNIST, MNIST-M, SVHN, USPS, SynthDigits] **Model:** A six-layer CNN with 14,219,210 parameters following the structure described in appendix D.2. -**Dataset:** This baseline makes use of the pre-processed partitions created and open source by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. +**Dataset:** This baseline makes use of the pre-processed partitions created and open-sourced by the authors of the FedBN paper. You can read more about how those were created [here](https://github.com/med-air/FedBN). Follow the steps below in the `Environment Setup` section to download them. A more detailed explanation of the datasets is given in the following table.
-| | MNIST | MNIST-M | SVHN | USPS | SynthDigits |
-|--- |--- |--- |--- |--- |--- |
-| data type| handwritten digits| MNIST modification randomly colored with colored patches| Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service | Syntehtic digits Windows TM font varying the orientation, blur and stroke colors |
-| color | greyscale | RGB | RGB | greyscale | RGB |
-| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 |
-| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 |
-| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 |
-| number of testset| 10.000 | 10.000 | 26.032 | - | - |
-| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) |
+|                    | MNIST              | MNIST-M                                                   | SVHN                      | USPS                                                          | SynthDigits                                                                       |
+| ------------------ | ------------------ | --------------------------------------------------------- | ------------------------- | ------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
+| data type          | handwritten digits | MNIST modification randomly colored with colored patches  | Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service  | Synthetic digits Windows TM font varying the orientation, blur and stroke colors  |
+| color              | greyscale          | RGB                                                       | RGB                       | greyscale                                                     | RGB                                                                               |
+| pixelsize          | 28x28              | 28x28                                                     | 32x32                     | 16x16                                                         | 32x32                                                                             |
+| labels             | 0-9                | 0-9                                                       | 1-10                      | 0-9                                                           | 1-10                                                                              |
+| number of trainset | 60,000             | 60,000                                                    | 73,257                    | 9,298                                                         | 50,000                                                                            |
+| number of testset  | 10,000             | 10,000                                                    | 26,032                    | -                                                             | -                                                                                 |
+| image shape        | (28,28)            | (28,28,3)                                                 | (32,32,3)                 | (16,16)                                                       | (32,32,3)                                                                         |
**Training Hyperparameters:** By default (i.e. if you don't override anything in the config) these main hyperparameters used are shown in the table below. For a complete list of hyperparameters, please refer to the config files in `fedbn/conf`.
-| Description | Value |
-| ----------- | ----- |
-| rounds | 10 |
-| num_clients | 5 |
-| strategy_fraction_fit | 1.0 |
-| strategy.fraction_evaluate | 0.0 |
-| training samples per client| 743 |
-| client.l_r | 10E-2 |
-| local epochs | 1 |
-| loss | cross entropy loss |
-| optimizer | SGD |
-| client_resources.num_cpu | 2 |
-| client_resources.num_gpus | 0.0 |
+| Description                 | Value              |
+| --------------------------- | ------------------ |
+| rounds                      | 10                 |
+| num_clients                 | 5                  |
+| strategy_fraction_fit       | 1.0                |
+| strategy.fraction_evaluate  | 0.0                |
+| training samples per client | 743                |
+| client.l_r                  | 10E-2              |
+| local epochs                | 1                  |
+| loss                        | cross entropy loss |
+| optimizer                   | SGD                |
+| client_resources.num_cpu    | 2                  |
+| client_resources.num_gpus   | 0.0                |
## Environment Setup @@ -93,7 +93,7 @@ cd data .. ## Running the Experiments -First, activate your environment via `poetry shell`. The commands below show how to run the experiments and modify some of its key hyperparameters via the cli. Each time you run an experiment, the log and results will be stored inside `outputs//
- 🇬🇧 - 🇫🇷 - 🇨🇳 + 🇬🇧 + 🇫🇷 + 🇨🇳
{% endif %} diff --git a/doc/source/_templates/sidebar/versioning.html b/doc/source/_templates/sidebar/versioning.html index dde7528d15e4..74f1cd8febb7 100644 --- a/doc/source/_templates/sidebar/versioning.html +++ b/doc/source/_templates/sidebar/versioning.html @@ -59,8 +59,8 @@ -
- +
+
diff --git a/doc/source/conf.py b/doc/source/conf.py index 259d8a988841..88cb5c05b1d8 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -123,25 +123,27 @@ # The full name is still at the top of the page add_module_names = False + def find_test_modules(package_path): """Go through the python files and exclude every *_test.py file.""" full_path_modules = [] for root, dirs, files in os.walk(package_path): for file in files: - if file.endswith('_test.py'): + if file.endswith("_test.py"): # Construct the module path relative to the package directory full_path = os.path.join(root, file) relative_path = os.path.relpath(full_path, package_path) # Convert file path to dotted module path - module_path = os.path.splitext(relative_path)[0].replace(os.sep, '.') + module_path = os.path.splitext(relative_path)[0].replace(os.sep, ".") full_path_modules.append(module_path) modules = [] for full_path_module in full_path_modules: - parts = full_path_module.split('.') + parts = full_path_module.split(".") for i in range(len(parts)): - modules.append('.'.join(parts[i:])) + modules.append(".".join(parts[i:])) return modules + # Stop from documenting the *_test.py files. # That's the only way to do that in autosummary (make the modules as mock_imports). autodoc_mock_imports = find_test_modules(os.path.abspath("../../src/py/flwr")) @@ -249,7 +251,7 @@ def find_test_modules(package_path): html_title = f"Flower Framework" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/framework/" +html_baseurl = "https://flower.ai/docs/framework/" html_theme_options = { # diff --git a/doc/source/contributor-explanation-architecture.rst b/doc/source/contributor-explanation-architecture.rst index 0e2ea1f6e66b..a20a84313118 100644 --- a/doc/source/contributor-explanation-architecture.rst +++ b/doc/source/contributor-explanation-architecture.rst @@ -4,7 +4,7 @@ Flower Architecture Edge Client Engine ------------------ -`Flower `_ core framework architecture with Edge Client Engine +`Flower `_ core framework architecture with Edge Client Engine .. figure:: _static/flower-architecture-ECE.png :width: 80 % @@ -12,7 +12,7 @@ Edge Client Engine Virtual Client Engine --------------------- -`Flower `_ core framework architecture with Virtual Client Engine +`Flower `_ core framework architecture with Virtual Client Engine .. figure:: _static/flower-architecture-VCE.png :width: 80 % @@ -20,7 +20,7 @@ Virtual Client Engine Virtual Client Engine and Edge Client Engine in the same workload ----------------------------------------------------------------- -`Flower `_ core framework architecture with both Virtual Client Engine and Edge Client Engine +`Flower `_ core framework architecture with both Virtual Client Engine and Edge Client Engine .. figure:: _static/flower-architecture.drawio.png :width: 80 % diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst index 2c6c7a7ab986..5dead265bee2 100644 --- a/doc/source/contributor-how-to-build-docker-images.rst +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -17,7 +17,7 @@ Before we can start, we need to meet a few prerequisites in our local developmen #. Verify the Docker daemon is running. Please follow the first section on - `Run Flower using Docker `_ + :doc:`Run Flower using Docker ` which covers this step in more detail. Currently, Flower provides two images, a base image and a server image. 
There will also be a client diff --git a/doc/source/contributor-how-to-contribute-translations.rst b/doc/source/contributor-how-to-contribute-translations.rst index d97a2cb8c64f..ba59901cf1c4 100644 --- a/doc/source/contributor-how-to-contribute-translations.rst +++ b/doc/source/contributor-how-to-contribute-translations.rst @@ -2,13 +2,13 @@ Contribute translations ======================= Since `Flower 1.5 -`_ we +`_ we have introduced translations to our doc pages, but, as you might have noticed, the translations are often imperfect. If you speak languages other than English, you might be able to help us in our effort to make Federated Learning accessible to as many people as possible by contributing to those translations! This might also be a great opportunity for those wanting to become open source -contributors with little prerequistes. +contributors with little prerequisites. Our translation project is publicly available over on `Weblate `_, this where most @@ -44,7 +44,7 @@ This is what the interface looks like: .. image:: _static/weblate_interface.png -You input your translation in the textbox at the top and then, once you are +You input your translation in the text box at the top and then, once you are happy with it, you either press ``Save and continue`` (to save the translation and go to the next untranslated string), ``Save and stay`` (to save the translation and stay on the same page), ``Suggest`` (to add your translation to @@ -67,5 +67,5 @@ Add new languages ----------------- If you want to add a new language, you will first have to contact us, either on -`Slack `_, or by opening an issue on our `GitHub +`Slack `_, or by opening an issue on our `GitHub repo `_. diff --git a/doc/source/contributor-how-to-create-new-messages.rst b/doc/source/contributor-how-to-create-new-messages.rst index 24fa5f573158..5d9f4600361c 100644 --- a/doc/source/contributor-how-to-create-new-messages.rst +++ b/doc/source/contributor-how-to-create-new-messages.rst @@ -29,8 +29,8 @@ Let's now see what we need to implement in order to get this simple function bet Message Types for Protocol Buffers ---------------------------------- -The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. -Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. +The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. +Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. Within the :code:`ServerMessage` block: diff --git a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst index 19d46c5753c6..c861457b6edc 100644 --- a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst +++ b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst @@ -8,7 +8,7 @@ When working on the Flower framework we want to ensure that all contributors use Workspace files are mounted from the local file system or copied or cloned into the container. Extensions are installed and run inside the container, where they have full access to the tools, platform, and file system. This means that you can seamlessly switch your entire development environment just by connecting to a different container. 
-Source: `Official VSCode documentation `_ +Source: `Official VSCode documentation `_ Getting started @@ -20,5 +20,5 @@ Now you should be good to go. When starting VSCode, it will ask you to run in th In some cases your setup might be more involved. For those cases consult the following sources: -* `Developing inside a Container `_ -* `Remote development in Containers `_ +* `Developing inside a Container `_ +* `Remote development in Containers `_ diff --git a/doc/source/contributor-ref-good-first-contributions.rst b/doc/source/contributor-ref-good-first-contributions.rst index 523a4679c6ef..2b8ce88413f5 100644 --- a/doc/source/contributor-ref-good-first-contributions.rst +++ b/doc/source/contributor-ref-good-first-contributions.rst @@ -14,7 +14,7 @@ Until the Flower core library matures it will be easier to get PR's accepted if they only touch non-core areas of the codebase. Good candidates to get started are: -- Documentation: What's missing? What could be expressed more clearly? +- Documentation: What's missing? What could be expressed more clearly? - Baselines: See below. - Examples: See below. @@ -22,11 +22,11 @@ are: Request for Flower Baselines ---------------------------- -If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. +If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. -You should then check out the open +You should then check out the open `issues `_ for baseline requests. -If you find a baseline that you'd like to work on and that has no assignes, feel free to assign it to yourself and start working on it! +If you find a baseline that you'd like to work on and that has no assignees, feel free to assign it to yourself and start working on it! Otherwise, if you don't find a baseline you'd like to work on, be sure to open a new issue with the baseline request template! diff --git a/doc/source/contributor-tutorial-contribute-on-github.rst b/doc/source/contributor-tutorial-contribute-on-github.rst index d409802897e4..6da81ce73662 100644 --- a/doc/source/contributor-tutorial-contribute-on-github.rst +++ b/doc/source/contributor-tutorial-contribute-on-github.rst @@ -3,8 +3,7 @@ Contribute on GitHub This guide is for people who want to get involved with Flower, but who are not used to contributing to GitHub projects. -If you're familiar with how contributing on GitHub works, you can directly checkout our -`getting started guide for contributors `_. +If you're familiar with how contributing on GitHub works, you can directly checkout our :doc:`getting started guide for contributors `. Setting up the repository @@ -12,21 +11,21 @@ Setting up the repository 1. **Create a GitHub account and setup Git** Git is a distributed version control tool. This allows for an entire codebase's history to be stored and every developer's machine. - It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. + It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. GitHub, itself, is a code hosting platform for version control and collaboration. It allows for everyone to collaborate and work from anywhere on remote repositories. - If you haven't already, you will need to create an account on `GitHub `_. + If you haven't already, you will need to create an account on `GitHub `_. 
- The idea behind the generic Git and GitHub workflow boils down to this: + The idea behind the generic Git and GitHub workflow boils down to this: you download code from a remote repository on GitHub, make changes locally and keep track of them using Git and then you upload your new history back to GitHub. 2. **Forking the Flower repository** - A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to https://github.com/adap/flower (while connected to your GitHub account) + A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to ``_ (while connected to your GitHub account) and click the ``Fork`` button situated on the top right of the page. .. image:: _static/fork_button.png - + You can change the name if you want, but this is not necessary as this version of Flower will be yours and will sit inside your own account (i.e., in your own list of repositories). Once created, you should see on the top left corner that you are looking at your own version of Flower. @@ -34,14 +33,14 @@ Setting up the repository 3. **Cloning your forked repository** The next step is to download the forked repository on your machine to be able to make changes to it. - On your forked repository page, you should first click on the ``Code`` button on the right, + On your forked repository page, you should first click on the ``Code`` button on the right, this will give you the ability to copy the HTTPS link of the repository. .. image:: _static/cloning_fork.png Once you copied the \, you can open a terminal on your machine, navigate to the place you want to download the repository to and type: - .. code-block:: shell + .. code-block:: shell $ git clone @@ -58,17 +57,17 @@ Setting up the repository To obtain it, we can do as previously mentioned by going to our fork repository on our GitHub account and copying the link. .. image:: _static/cloning_fork.png - + Once the \ is copied, we can type the following command in our terminal: .. code-block:: shell $ git remote add origin - + 5. **Add upstream** Now we will add an upstream address to our repository. - Still in the same directroy, we must run the following command: + Still in the same directory, we must run the following command: .. code-block:: shell @@ -76,10 +75,10 @@ Setting up the repository The following diagram visually explains what we did in the previous steps: - .. image:: _static/github_schema.png + .. image:: _static/github_schema.png - The upstream is the GitHub remote address of the parent repository (in this case Flower), - i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. + The upstream is the GitHub remote address of the parent repository (in this case Flower), + i.e. the one we eventually want to contribute to and therefore need an up-to-date history of. The origin is just the GitHub remote address of the forked repository we created, i.e. the copy (fork) in our own account. To make sure our local version of the fork is up-to-date with the latest changes from the Flower repository, @@ -93,7 +92,7 @@ Setting up the repository Setting up the coding environment --------------------------------- -This can be achieved by following this `getting started guide for contributors`_ (note that you won't need to clone the repository). +This can be achieved by following this :doc:`getting started guide for contributors ` (note that you won't need to clone the repository). 
Once you are able to write code and test it, you can finally start making changes! @@ -113,9 +112,9 @@ And with Flower's repository: $ git pull upstream main 1. **Create a new branch** - To make the history cleaner and easier to work with, it is good practice to + To make the history cleaner and easier to work with, it is good practice to create a new branch for each feature/project that needs to be implemented. - + To do so, just run the following command inside the repository's directory: .. code-block:: shell @@ -137,7 +136,7 @@ And with Flower's repository: $ ./dev/test.sh # to test that your code can be accepted $ ./baselines/dev/format.sh # same as above but for code added to baselines $ ./baselines/dev/test.sh # same as above but for code added to baselines - + 4. **Stage changes** Before creating a commit that will update your history, you must specify to Git which files it needs to take into account. @@ -184,21 +183,21 @@ Creating and merging a pull request (PR) Once you click the ``Compare & pull request`` button, you should see something similar to this: .. image:: _static/creating_pr.png - + At the top you have an explanation of which branch will be merged where: .. image:: _static/merging_branch.png - + In this example you can see that the request is to merge the branch ``doc-fixes`` from my forked repository to branch ``main`` from the Flower repository. - The input box in the middle is there for you to describe what your PR does and to link it to existing issues. + The input box in the middle is there for you to describe what your PR does and to link it to existing issues. We have placed comments (that won't be rendered once the PR is opened) to guide you through the process. It is important to follow the instructions described in comments. For instance, in order to not break how our changelog system works, you should read the information above the ``Changelog entry`` section carefully. You can also checkout some examples and details in the :ref:`changelogentry` appendix. - At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and + At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and that they should look over it to merge or to request changes. If your PR is not yet ready for review, and you don't want to notify anyone, you have the option to create a draft pull request: @@ -218,7 +217,7 @@ Creating and merging a pull request (PR) Merging will be blocked if there are ongoing requested changes. .. image:: _static/changes_requested.png - + To resolve them, just push the necessary changes to the branch associated with the PR: .. image:: _static/make_changes.png @@ -256,36 +255,36 @@ Example of first contribution Problem ******* -For our documentation, we’ve started to use the `Diàtaxis framework `_. +For our documentation, we've started to use the `Diàtaxis framework `_. -Our “How to” guides should have titles that continue the sencence “How to …”, for example, “How to upgrade to Flower 1.0”. +Our "How to" guides should have titles that continue the sentence "How to …", for example, "How to upgrade to Flower 1.0". Most of our guides do not follow this new format yet, and changing their title is (unfortunately) more involved than one might think. -This issue is about changing the title of a doc from present continious to present simple. +This issue is about changing the title of a doc from present continuous to present simple. 
-Let's take the example of “Saving Progress” which we changed to “Save Progress”. Does this pass our check?
+Let's take the example of "Saving Progress" which we changed to "Save Progress". Does this pass our check?

-Before: ”How to saving progress” ❌
+Before: "How to saving progress" ❌

-After: ”How to save progress” ✅
+After: "How to save progress" ✅

Solution
********

-This is a tiny change, but it’ll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here’s what you should do:
+This is a tiny change, but it'll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here's what you should do:

- Find the source file in ``doc/source``
- Make the change in the ``.rst`` file (beware, the dashes under the title should be the same length as the title itself)
-- Build the docs and check the result: ``_
+- Build the docs and `check the result `_

Rename file
:::::::::::

-You might have noticed that the file name still reflects the old wording.
+You might have noticed that the file name still reflects the old wording.
 If we just change the file, then we break all existing links to it - it is **very important** to avoid that, as breaking links can harm our search engine ranking.

-Here’s how to change the file name:
+Here's how to change the file name:

- Change the file name to ``save-progress.rst``
- Add a redirect rule to ``doc/source/conf.py`` (see the sketch at the end of this guide)

@@ -295,7 +294,7 @@ This will cause a redirect from ``saving-progress.html`` to ``save-progress.html

Apply changes in the index file
:::::::::::::::::::::::::::::::

-For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well.
+For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well.
 This is where we define the whole tree structure of the navbar.

- Find and modify the file name in ``index.rst``

@@ -303,7 +302,7 @@ This is where we define the whole arborescence of the navbar.

Open PR
:::::::

-- Commit the changes (commit messages are always imperative: “Do something”, in this case “Change …”)
+- Commit the changes (commit messages are always imperative: "Do something", in this case "Change …")
- Push the changes to your fork
- Open a PR (as shown above)
- Wait for it to be approved!
@@ -343,7 +342,7 @@ Next steps

Once you have made your first PR and want to contribute more, be sure to check out the following:

-- `Good first contributions `_, where you should particularly look into the :code:`baselines` contributions.
+- :doc:`Good first contributions <contributor-ref-good-first-contributions>`, where you should particularly look into the :code:`baselines` contributions.


Appendix
@@ -358,10 +357,10 @@ When opening a new PR, inside its description, there should be a ``Changelog ent

 Above this header you should see the following comment that explains how to write your changelog entry:

- Inside the following 'Changelog entry' section,
+ Inside the following 'Changelog entry' section,
 you should put the description of your changes that will be added to the changelog alongside your PR title.

- If the section is completely empty (without any token) or non-existant,
+ If the section is completely empty (without any token) or non-existent,
 the changelog will just contain the title of the PR for the changelog entry, without any description.

 If the section contains some text other than tokens, it will use it to add a description to the change.
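As for the redirect rule mentioned earlier in this guide, a minimal sketch could look like the following, assuming the docs use the ``sphinx-reredirects`` extension (the exact mechanism in ``doc/source/conf.py`` may differ):

.. code-block:: python

    # doc/source/conf.py -- hypothetical sketch, assuming sphinx-reredirects
    extensions = [
        # ... existing extensions ...
        "sphinx_reredirects",
    ]

    # Map the old page name to the new one so existing links keep working
    redirects = {
        "saving-progress": "save-progress.html",
    }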
diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst
index 72c6df5fdbc7..01810c7244d3 100644
--- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst
+++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst
@@ -11,7 +11,7 @@ Prerequisites

Flower uses :code:`pyproject.toml` to manage dependencies and configure
development tools (the ones which support it). Poetry is a build tool which
-supports `PEP 517 `_.
+supports `PEP 517 <https://peps.python.org/pep-0517/>`_.


Developer Machine Setup
@@ -27,7 +27,7 @@ For macOS

* Install `homebrew <https://brew.sh/>`_. Don't forget the post-installation actions to add `brew` to your PATH.
* Install `xz` (to install different Python versions) and `pandoc` to build the docs::
-
+
  $ brew install xz pandoc

For Ubuntu
@@ -54,7 +54,7 @@ GitHub::

* If you don't have :code:`pyenv` installed, the following script will install it, set it up, and create the virtual environment (with :code:`Python 3.8.17` by default)::

  $ ./dev/setup-defaults.sh # once completed, run the bootstrap script
-
+
* If you already have :code:`pyenv` installed (along with the :code:`pyenv-virtualenv` plugin), you can use the following convenience script (with :code:`Python 3.8.17` by default)::

  $ ./dev/venv-create.sh # once completed, run the `bootstrap.sh` script
diff --git a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst
index 5ebaa337dde8..0139f3b8dc31 100644
--- a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst
+++ b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst
@@ -3,11 +3,11 @@ Example: FedBN in PyTorch - From Centralized To Federated

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with `FedBN <https://arxiv.org/abs/2102.07623>`_, a federated training strategy designed for non-iid data.
We are using PyTorch to train a Convolutional Neural Network (with Batch Normalization layers) on the CIFAR-10 dataset.
-When applying FedBN, only few changes needed compared to `Example: PyTorch - From Centralized To Federated `_.
+When applying FedBN, only a few changes are needed compared to :doc:`Example: PyTorch - From Centralized To Federated <example-pytorch-from-centralized-to-federated>`.

Centralized Training
--------------------
-All files are revised based on `Example: PyTorch - From Centralized To Federated `_.
+All files are revised based on :doc:`Example: PyTorch - From Centralized To Federated <example-pytorch-from-centralized-to-federated>`.
The only thing to do is to modify the file called :code:`cifar.py`; the revised part is shown below:

The model architecture defined in class Net() is extended with Batch Normalization layers accordingly.

@@ -45,13 +45,13 @@ You can now run your machine learning workload:

 python3 cifar.py

So far this should all look fairly familiar if you've used PyTorch before.
-Let's take the next step and use what we've built to create a federated learning system within FedBN, the sytstem consists of one server and two clients.
+Let's take the next step and use what we've built to create a federated learning system with FedBN; the system consists of one server and two clients.

Federated Training
------------------

-If you have read `Example: PyTorch - From Centralized To Federated `_, the following parts are easy to follow, onyl :code:`get_parameters` and :code:`set_parameters` function in :code:`client.py` needed to revise.
-If not, please read the `Example: PyTorch - From Centralized To Federated `_. first.
+If you have read :doc:`Example: PyTorch - From Centralized To Federated <example-pytorch-from-centralized-to-federated>`, the following parts are easy to follow; only the :code:`get_parameters` and :code:`set_parameters` functions in :code:`client.py` need to be revised.
+If not, please read the :doc:`Example: PyTorch - From Centralized To Federated <example-pytorch-from-centralized-to-federated>` first.

Our example consists of one *server* and two *clients*. In FedBN, :code:`server.py` remains unchanged, so we can start the server directly.

@@ -66,7 +66,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an

 class CifarClient(fl.client.NumPyClient):
     """Flower client implementing CIFAR-10 image classification using PyTorch."""
-
+
     ...

     def get_parameters(self, config) -> List[np.ndarray]:
@@ -79,7 +79,7 @@ Finally, we will revise our *client* logic by changing :code:`get_parameters` an
         params_dict = zip(keys, parameters)
         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
         self.model.load_state_dict(state_dict, strict=False)
-
+
         ...

Now you can open two additional terminal windows and run

diff --git a/doc/source/example-walkthrough-pytorch-mnist.rst b/doc/source/example-walkthrough-pytorch-mnist.rst
index ab311813f5de..8717c196f043 100644
--- a/doc/source/example-walkthrough-pytorch-mnist.rst
+++ b/doc/source/example-walkthrough-pytorch-mnist.rst
@@ -1,11 +1,11 @@
Example: Walk-Through PyTorch & MNIST
=====================================

-In this tutorial we will learn, how to train a Convolutional Neural Network on MNIST using Flower and PyTorch.
+In this tutorial we will learn how to train a Convolutional Neural Network on MNIST using Flower and PyTorch.

-Our example consists of one *server* and two *clients* all having the same model.
+Our example consists of one *server* and two *clients*, all having the same model.

-*Clients* are responsible for generating individual weight-updates for the model based on their local datasets.
+*Clients* are responsible for generating individual weight-updates for the model based on their local datasets.
These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*.

@@ -15,7 +15,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n

 $ pip install flwr

-Since we want to use PyTorch to solve a computer vision task, let's go ahead an install PyTorch and the **torchvision** library:
+Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library:

.. code-block:: shell

@@ -32,51 +32,51 @@ Go ahead and launch on a terminal the *run-server.sh* script first as follows:

.. code-block:: shell

- $ bash ./run-server.sh
+ $ bash ./run-server.sh

-Now that the server is up and running, go ahead and launch the clients.
+Now that the server is up and running, go ahead and launch the clients.

.. code-block:: shell

- $ bash ./run-clients.sh
+ $ bash ./run-clients.sh

Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client.

..
code-block:: shell

- Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014
-
- Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403
-
- Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280
-
- Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641
-
- Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784
-
- Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134
-
- Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16%
-
+ Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014
+
+ Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403
+
+ Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280
+
+ Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641
+
+ Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784
+
+ Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134
+
+ Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16%
+
 Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14%

-Now, let's see what is really happening inside.
+Now, let's see what is really happening inside.

Flower Server
-------------

Inside the server helper script *run-server.sh* you will find the following code that basically runs the :code:`server.py`:

-.. code-block:: bash
+.. code-block:: bash

 python -m flwr_example.quickstart-pytorch.server

We can go a bit deeper and see that :code:`server.py` simply launches a server that will coordinate three rounds of training.
-Flower Servers are very customizable, but for simple workloads, we can start a server using the :ref:`start_server ` function and leave all the configuration possibilities at their default values, as seen below.
+Flower Servers are very customizable, but for simple workloads, we can start a server using the `start_server `_ function and leave all the configuration possibilities at their default values, as seen below.

.. code-block:: python

@@ -90,18 +90,18 @@ Flower Client

Next, let's take a look at the *run-clients.sh* file. You will see that it contains the main loop that starts a set of *clients*.

-.. code-block:: bash
+.. code-block:: bash

 python -m flwr_example.quickstart-pytorch.client \
    --cid=$i \
    --server_address=$SERVER_ADDRESS \
-   --nb_clients=$NUM_CLIENTS
+   --nb_clients=$NUM_CLIENTS

* **cid**: is the client ID. It is an integer that uniquely identifies the client.
-* **sever_address**: String that identifies IP and port of the server.
+* **server_address**: String that identifies the IP and port of the server.
* **nb_clients**: This defines the number of clients being created. This piece of information is not required by the client, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both *training* and *test* sets.

-Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`.
+Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`.
After going through the argument parsing code at the beginning of our :code:`main` function, you will find a call to :code:`mnist.load_data`. This function is responsible for partitioning the original MNIST datasets (*training* and *test*) and returning a :code:`torch.utils.data.DataLoader` for each of them.
We then instantiate a :code:`PytorchMNISTClient` object with our client ID, our DataLoaders, the number of epochs in each round, and which device we want to use for training (CPU or GPU).
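To make the flow concrete, here is a rough sketch of such a :code:`main` function; the helpers (:code:`mnist.load_data`, :code:`PytorchMNISTClient`) are the ones described in this walkthrough, but the module path and exact signatures used here are assumptions:

.. code-block:: python

    # Hypothetical sketch of the client entry point described above
    import argparse

    import flwr as fl
    import torch

    from flwr_example.quickstart_pytorch import mnist  # module path assumed


    def main() -> None:
        parser = argparse.ArgumentParser(description="Flower MNIST client")
        parser.add_argument("--cid", type=int, required=True)
        parser.add_argument("--server_address", type=str, required=True)
        parser.add_argument("--nb_clients", type=int, required=True)
        args = parser.parse_args()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Partition MNIST so that every client works on a unique subset
        train_loader, test_loader = mnist.load_data(
            cid=args.cid, nb_clients=args.nb_clients
        )

        # Wrap everything in the client class and connect to the server
        client = mnist.PytorchMNISTClient(
            cid=args.cid,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=10,
            device=device,
        )
        fl.client.start_client(server_address=args.server_address, client=client)


    if __name__ == "__main__":
        main()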
@@ -152,7 +152,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e

 Parameters
 ----------
- weights: fl.common.NDArrays
+ weights: fl.common.NDArrays
 Weights received by the server and set to local model

@@ -179,8 +179,8 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e

 Parameters
 ----------
- ins: fl.common.FitIns
- Parameters sent by the server to be used during training.
+ ins: fl.common.FitIns
+ Parameters sent by the server to be used during training.

 Returns
 -------
@@ -214,9 +214,9 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e

 Parameters
 ----------
- ins: fl.common.EvaluateIns
- Parameters sent by the server to be used during testing.
-
+ ins: fl.common.EvaluateIns
+ Parameters sent by the server to be used during testing.
+
 Returns
 -------
@@ -262,9 +262,9 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it

 Parameters
 ----------
- x: Tensor
+ x: Tensor
 Mini-batch of shape (N,28,28) containing images from MNIST dataset.
-
+
 Returns
 -------
@@ -287,7 +287,7 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it

 return output

-The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods:
+The second thing to notice is that the :code:`PytorchMNISTClient` class inherits from :code:`fl.client.Client`, and hence it must implement the following methods:

.. code-block:: python

@@ -312,7 +312,7 @@ The second thing to notice is that :code:`PytorchMNISTClient` class inherits fro

 """Evaluate the provided weights using the locally held dataset."""

-When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function.
+When comparing the abstract class to its derived class :code:`PytorchMNISTClient`, you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test` function.

These functions can both be found inside the same :code:`quickstart-pytorch.mnist` module:

@@ -330,14 +330,14 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis
 ----------
 model: torch.nn.ModuleList
 Neural network model used in this example.
-
+
 train_loader: torch.utils.data.DataLoader
 DataLoader used in training.
-
- epochs: int
- Number of epochs to run in each round.
-
- device: torch.device
+
+ epochs: int
+ Number of epochs to run in each round.
+
+ device: torch.device
 (Default value = torch.device("cpu"))
 Device where the network will be trained within a client.

@@ -399,10 +399,10 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis
 ----------
 model: torch.nn.ModuleList :
 Neural network model used in this example.
-
+
 test_loader: torch.utils.data.DataLoader :
 DataLoader used in testing.
-
+
 device: torch.device :
 (Default value = torch.device("cpu"))
 Device where the network will be tested within a client.

@@ -435,19 +435,19 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis

Observe that these functions encapsulate regular training and test loops and provide :code:`fit` and :code:`evaluate` with final statistics for each round.

-You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly.
-As a matter of fact, why not try and modify the code to an example of your liking?
+You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly.
+As a matter of fact, why not try and modify the code to an example of your liking?

Give It a Try
-------------
-Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper.
+Looking through the quickstart code description above will have given you a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper.
 Here are a few things you could try on your own and get more experience with Flower:

- Try and change :code:`PytorchMNISTClient` so it can accept different architectures.
- Modify the :code:`train` function so that it accepts different optimizers.
- Modify the :code:`test` function so that it reports not only the top-1 (regular) accuracy but also the top-5 accuracy.
-- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50?
+- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50?

You are ready now. Enjoy learning in a federated way!

diff --git a/doc/source/how-to-configure-clients.rst b/doc/source/how-to-configure-clients.rst
index 26c132125ccf..ff0a2f4033df 100644
--- a/doc/source/how-to-configure-clients.rst
+++ b/doc/source/how-to-configure-clients.rst
@@ -13,7 +13,7 @@ Configuration values are represented as a dictionary with ``str`` keys and value

 config_dict = {
     "dropout": True,  # str key, bool value
     "learning_rate": 0.01,  # str key, float value
-    "batch_size": 32,  # str key, int value
+    "batch_size": 32,  # str key, int value
     "optimizer": "sgd",  # str key, str value
 }

@@ -56,7 +56,7 @@ To make the built-in strategies use this function, we can pass it to ``FedAvg``

On the client side, we receive the configuration dictionary in ``fit``:

.. code-block:: python
-
+
 class FlowerClient(flwr.client.NumPyClient):
     def fit(self, parameters, config):
         print(config["batch_size"])  # Prints `32`

@@ -86,7 +86,7 @@ Configuring individual clients

In some cases, it is necessary to send different configuration values to different clients.

-This can be achieved by customizing an existing strategy or by `implementing a custom strategy from scratch `_. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list; the other clients in this round do not receive this "special" config value):
+This can be achieved by customizing an existing strategy or by :doc:`implementing a custom strategy from scratch <how-to-implement-strategies>`. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list; the other clients in this round do not receive this "special" config value):
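.. code-block:: python

    # NOTE: a sketch reconstructed from the description above -- the class name
    # and the exact override are assumptions, not necessarily the original snippet.
    class CustomClientConfigStrategy(fl.server.strategy.FedAvg):
        def configure_fit(self, server_round, parameters, client_manager):
            client_instructions = super().configure_fit(
                server_round, parameters, client_manager
            )

            # Add a special "hello": "world" config key/value pair,
            # but only for the first client in this round
            _, fit_ins = client_instructions[0]  # first (ClientProxy, FitIns) pair
            fit_ins.config["hello"] = "world"

            return client_instructions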
diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst
index ff3dbb605846..aebe5f7316de 100644
--- a/doc/source/how-to-install-flower.rst
+++ b/doc/source/how-to-install-flower.rst
@@ -57,7 +57,7 @@ Advanced installation options

Install via Docker
~~~~~~~~~~~~~~~~~~

-`How to run Flower using Docker `_
+:doc:`How to run Flower using Docker <how-to-run-flower-using-docker>`

Install pre-release
~~~~~~~~~~~~~~~~~~~
diff --git a/doc/source/how-to-monitor-simulation.rst b/doc/source/how-to-monitor-simulation.rst
index 740004914eed..61a3be68deec 100644
--- a/doc/source/how-to-monitor-simulation.rst
+++ b/doc/source/how-to-monitor-simulation.rst
@@ -231,6 +231,6 @@ A: Either the simulation has already finished, or you still need to start Promet

Resources
---------

-Ray Dashboard: ``_
+Ray Dashboard: ``_

-Ray Metrics: ``_
+Ray Metrics: ``_
diff --git a/doc/source/how-to-run-flower-using-docker.rst b/doc/source/how-to-run-flower-using-docker.rst
index 40df1ffcb63c..ed034c820142 100644
--- a/doc/source/how-to-run-flower-using-docker.rst
+++ b/doc/source/how-to-run-flower-using-docker.rst
@@ -54,7 +54,7 @@ to the Flower server. Here, we are passing the flag ``--insecure``.

 The ``--insecure`` flag enables insecure communication (using HTTP, not HTTPS)
 and should only be used for testing purposes. We strongly recommend enabling
- `SSL `_
+ `SSL `_
 when deploying to a production environment.

You can use ``--help`` to view all available flags that the server supports:

@@ -90,7 +90,7 @@ To enable SSL, you will need a CA certificate, a server certificate and a server

.. note::
 For testing purposes, you can generate your own self-signed certificates. The
- `Enable SSL connections `_
+ `Enable SSL connections `_
 page contains a section that will guide you through the process.

Assuming all files we need are in the local ``certificates`` directory, we can use the flag
diff --git a/doc/source/how-to-run-simulations.rst b/doc/source/how-to-run-simulations.rst
index 6e0520a79bf5..d1dcb511ed51 100644
--- a/doc/source/how-to-run-simulations.rst
+++ b/doc/source/how-to-run-simulations.rst
@@ -29,7 +29,7 @@ Running Flower simulations still require you to define your client class, a stra

 def client_fn(cid: str):
     # Return a standard Flower client
-    return MyFlowerClient()
+    return MyFlowerClient().to_client()

 # Launch the simulation
 hist = fl.simulation.start_simulation(
diff --git a/doc/source/how-to-upgrade-to-flower-1.0.rst b/doc/source/how-to-upgrade-to-flower-1.0.rst
index fd380e95d69c..c4429d61d0a9 100644
--- a/doc/source/how-to-upgrade-to-flower-1.0.rst
+++ b/doc/source/how-to-upgrade-to-flower-1.0.rst
@@ -50,7 +50,7 @@ Strategies / ``start_server`` / ``start_simulation``

- Replace ``num_rounds=1`` in ``start_simulation`` with the new ``config=ServerConfig(...)`` (see previous item)
- Remove ``force_final_distributed_eval`` parameter from calls to ``start_server``. Distributed evaluation on all clients can be enabled by configuring the strategy to sample all clients for evaluation after the last round of training.
- Rename parameter/ndarray conversion functions:
-
+
  - ``parameters_to_weights`` --> ``parameters_to_ndarrays``
  - ``weights_to_parameters`` --> ``ndarrays_to_parameters``

@@ -88,4 +88,4 @@ Along with the necessary changes above, there are a number of potential improvem

Further help
------------

-Most official `Flower code examples `_ are already updated to Flower 1.0, they can serve as a reference for using the Flower 1.0 API.
If there are further questionsm, `join the Flower Slack `_ and use the channgel ``#questions``.
+Most official `Flower code examples `_ are already updated to Flower 1.0, and they can serve as a reference for using the Flower 1.0 API. If there are further questions, `join the Flower Slack <https://flower.ai/join-slack>`_ and use the channel ``#questions``.
diff --git a/doc/source/how-to-use-built-in-mods.rst b/doc/source/how-to-use-built-in-mods.rst
index af7102de9d0b..341139175074 100644
--- a/doc/source/how-to-use-built-in-mods.rst
+++ b/doc/source/how-to-use-built-in-mods.rst
@@ -86,4 +86,4 @@ Conclusion

By following this guide, you have learned how to effectively use mods to enhance your ``ClientApp``'s functionality. Remember that the order of mods is crucial and affects how the input and output are processed.

-Enjoy building more robust and flexible ``ClientApp``s with mods!
+Enjoy building a more robust and flexible ``ClientApp`` with mods!
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 7e2b4052bee6..ea52a9421b61 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -4,7 +4,7 @@ Flower Framework Documentation

.. meta::
   :description: Check out the documentation of the main Flower Framework enabling easy Python development for Federated Learning.

-Welcome to Flower's documentation. `Flower `_ is a friendly federated learning framework.
+Welcome to Flower's documentation. `Flower <https://flower.ai>`_ is a friendly federated learning framework.


Join the Flower Community
@@ -12,7 +12,7 @@ Join the Flower Community

The Flower Community is growing quickly - we're a friendly group of researchers, engineers, students, professionals, academics, and other enthusiasts.

-.. button-link:: https://flower.dev/join-slack
+.. button-link:: https://flower.ai/join-slack
   :color: primary
   :shadow:
diff --git a/doc/source/ref-api-cli.rst b/doc/source/ref-api-cli.rst
index c0e8940061fc..63579143755d 100644
--- a/doc/source/ref-api-cli.rst
+++ b/doc/source/ref-api-cli.rst
@@ -31,22 +31,22 @@ flower-fleet-api
   :func: _parse_args_run_fleet_api
   :prog: flower-fleet-api

-.. .. _flower-client-app-apiref:
+.. _flower-client-app-apiref:

-.. flower-client-app
-.. ~~~~~~~~~~~~~~~~~
+flower-client-app
+~~~~~~~~~~~~~~~~~

-.. .. argparse::
-.. :filename: flwr.client
-.. :func: _parse_args_run_client_app
-.. :prog: flower-client-app
+.. argparse::
+   :module: flwr.client.app
+   :func: _parse_args_run_client_app
+   :prog: flower-client-app

-.. .. _flower-server-app-apiref:
+.. _flower-server-app-apiref:

-.. flower-server-app
-.. ~~~~~~~~~~~~~~~~~
+flower-server-app
+~~~~~~~~~~~~~~~~~

-.. .. argparse::
-.. :filename: flwr.server
-.. :func: _parse_args_run_server_app
-.. :prog: flower-server-app
+..
argparse:: + :module: flwr.server.run_serverapp + :func: _parse_args_run_server_app + :prog: flower-server-app diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 78d1e0e491a4..54092e15a564 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -42,7 +42,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce Docker image for Flower server** ([#2700](https://github.com/adap/flower/pull/2700), [#2688](https://github.com/adap/flower/pull/2688), [#2705](https://github.com/adap/flower/pull/2705), [#2695](https://github.com/adap/flower/pull/2695), [#2747](https://github.com/adap/flower/pull/2747), [#2746](https://github.com/adap/flower/pull/2746), [#2680](https://github.com/adap/flower/pull/2680), [#2682](https://github.com/adap/flower/pull/2682), [#2701](https://github.com/adap/flower/pull/2701)) - The Flower server can now be run using an official Docker image. A new how-to guide explains [how to run Flower using Docker](https://flower.dev/docs/framework/how-to-run-flower-using-docker.html). An official Flower client Docker image will follow. + The Flower server can now be run using an official Docker image. A new how-to guide explains [how to run Flower using Docker](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html). An official Flower client Docker image will follow. - **Introduce** `flower-via-docker-compose` **example** ([#2626](https://github.com/adap/flower/pull/2626)) @@ -52,7 +52,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Update code examples to use Flower Datasets** ([#2450](https://github.com/adap/flower/pull/2450), [#2456](https://github.com/adap/flower/pull/2456), [#2318](https://github.com/adap/flower/pull/2318), [#2712](https://github.com/adap/flower/pull/2712)) - Several code examples were updated to use [Flower Datasets](https://flower.dev/docs/datasets/). + Several code examples were updated to use [Flower Datasets](https://flower.ai/docs/datasets/). - **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381), [#2805](https://github.com/adap/flower/pull/2805), [#2782](https://github.com/adap/flower/pull/2782), [#2806](https://github.com/adap/flower/pull/2806), [#2829](https://github.com/adap/flower/pull/2829), [#2825](https://github.com/adap/flower/pull/2825), [#2816](https://github.com/adap/flower/pull/2816), [#2726](https://github.com/adap/flower/pull/2726), [#2659](https://github.com/adap/flower/pull/2659), [#2655](https://github.com/adap/flower/pull/2655)) @@ -213,11 +213,11 @@ We would like to give our special thanks to all the contributors who made the ne The new simulation engine has been rewritten from the ground up, yet it remains fully backwards compatible. It offers much improved stability and memory handling, especially when working with GPUs. Simulations transparently adapt to different settings to scale simulation in CPU-only, CPU+GPU, multi-GPU, or multi-node multi-GPU environments. - Comprehensive documentation includes a new [how-to run simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.dev/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.dev/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). 
+ Comprehensive documentation includes a new [how-to run simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) guide, new [simulation-pytorch](https://flower.ai/docs/examples/simulation-pytorch.html) and [simulation-tensorflow](https://flower.ai/docs/examples/simulation-tensorflow.html) notebooks, and a new [YouTube tutorial series](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB). - **Restructure Flower Docs** ([#1824](https://github.com/adap/flower/pull/1824), [#1865](https://github.com/adap/flower/pull/1865), [#1884](https://github.com/adap/flower/pull/1884), [#1887](https://github.com/adap/flower/pull/1887), [#1919](https://github.com/adap/flower/pull/1919), [#1922](https://github.com/adap/flower/pull/1922), [#1920](https://github.com/adap/flower/pull/1920), [#1923](https://github.com/adap/flower/pull/1923), [#1924](https://github.com/adap/flower/pull/1924), [#1962](https://github.com/adap/flower/pull/1962), [#2006](https://github.com/adap/flower/pull/2006), [#2133](https://github.com/adap/flower/pull/2133), [#2203](https://github.com/adap/flower/pull/2203), [#2215](https://github.com/adap/flower/pull/2215), [#2122](https://github.com/adap/flower/pull/2122), [#2223](https://github.com/adap/flower/pull/2223), [#2219](https://github.com/adap/flower/pull/2219), [#2232](https://github.com/adap/flower/pull/2232), [#2233](https://github.com/adap/flower/pull/2233), [#2234](https://github.com/adap/flower/pull/2234), [#2235](https://github.com/adap/flower/pull/2235), [#2237](https://github.com/adap/flower/pull/2237), [#2238](https://github.com/adap/flower/pull/2238), [#2242](https://github.com/adap/flower/pull/2242), [#2231](https://github.com/adap/flower/pull/2231), [#2243](https://github.com/adap/flower/pull/2243), [#2227](https://github.com/adap/flower/pull/2227)) - Much effort went into a completely restructured Flower docs experience. The documentation on [flower.dev/docs](flower.dev/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. + Much effort went into a completely restructured Flower docs experience. The documentation on [flower.ai/docs](https://flower.ai/docs) is now divided into Flower Framework, Flower Baselines, Flower Android SDK, Flower iOS SDK, and code example projects. - **Introduce Flower Swift SDK** ([#1858](https://github.com/adap/flower/pull/1858), [#1897](https://github.com/adap/flower/pull/1897)) @@ -303,7 +303,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new "What is Federated Learning?" tutorial** ([#1657](https://github.com/adap/flower/pull/1657), [#1721](https://github.com/adap/flower/pull/1721)) - A new [entry-level tutorial](https://flower.dev/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Fedetated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! + A new [entry-level tutorial](https://flower.ai/docs/framework/tutorial-what-is-federated-learning.html) in our documentation explains the basics of Fedetated Learning. It enables anyone who's unfamiliar with Federated Learning to start their journey with Flower. Forward it to anyone who's interested in Federated Learning! 
- **Introduce new Flower Baseline: FedProx MNIST** ([#1513](https://github.com/adap/flower/pull/1513), [#1680](https://github.com/adap/flower/pull/1680), [#1681](https://github.com/adap/flower/pull/1681), [#1679](https://github.com/adap/flower/pull/1679)) @@ -417,7 +417,7 @@ We would like to give our special thanks to all the contributors who made the ne - **Introduce new Flower Baseline: FedAvg MNIST** ([#1497](https://github.com/adap/flower/pull/1497), [#1552](https://github.com/adap/flower/pull/1552)) - Over the coming weeks, we will be releasing a number of new reference implementations useful especially to FL newcomers. They will typically revisit well known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.dev/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) + Over the coming weeks, we will be releasing a number of new reference implementations useful especially to FL newcomers. They will typically revisit well known papers from the literature, and be suitable for integration in your own application or for experimentation, in order to deepen your knowledge of FL in general. Today's release is the first in this series. [Read more.](https://flower.ai/blog/2023-01-12-fl-starter-pack-fedavg-mnist-cnn/) - **Improve GPU support in simulations** ([#1555](https://github.com/adap/flower/pull/1555)) @@ -427,16 +427,16 @@ We would like to give our special thanks to all the contributors who made the ne Some users reported that Jupyter Notebooks have not always been easy to use on GPU instances. We listened and made improvements to all of our Jupyter notebooks! Check out the updated notebooks here: - - [An Introduction to Federated Learning](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) - - [Strategies in Federated Learning](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) - - [Building a Strategy](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) - - [Client and NumPyClient](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) + - [An Introduction to Federated Learning](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) + - [Strategies in Federated Learning](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) + - [Building a Strategy](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) + - [Client and NumPyClient](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) - **Introduce optional telemetry** ([#1533](https://github.com/adap/flower/pull/1533), [#1544](https://github.com/adap/flower/pull/1544), [#1584](https://github.com/adap/flower/pull/1584)) After a [request for feedback](https://github.com/adap/flower/issues/1534) from the community, the Flower open-source project introduces optional collection of *anonymous* usage metrics to make well-informed decisions to improve Flower. Doing this enables the Flower team to understand how Flower is used and what challenges users might face. - **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. [Read more.](https://flower.dev/docs/telemetry.html). 
+ **Flower is a friendly framework for collaborative AI and data science.** Staying true to this statement, Flower makes it easy to disable telemetry for users who do not want to share anonymous usage metrics. [Read more.](https://flower.ai/docs/telemetry.html). - **Introduce (experimental) Driver API** ([#1520](https://github.com/adap/flower/pull/1520), [#1525](https://github.com/adap/flower/pull/1525), [#1545](https://github.com/adap/flower/pull/1545), [#1546](https://github.com/adap/flower/pull/1546), [#1550](https://github.com/adap/flower/pull/1550), [#1551](https://github.com/adap/flower/pull/1551), [#1567](https://github.com/adap/flower/pull/1567)) @@ -468,7 +468,7 @@ We would like to give our special thanks to all the contributors who made the ne As usual, the documentation has improved quite a bit. It is another step in our effort to make the Flower documentation the best documentation of any project. Stay tuned and as always, feel free to provide feedback! - One highlight is the new [first time contributor guide](https://flower.dev/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! + One highlight is the new [first time contributor guide](https://flower.ai/docs/first-time-contributors.html): if you've never contributed on GitHub before, this is the perfect place to start! ### Incompatible changes @@ -657,7 +657,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Flower Baselines (preview): FedOpt, FedBN, FedAvgM** ([#919](https://github.com/adap/flower/pull/919), [#1127](https://github.com/adap/flower/pull/1127), [#914](https://github.com/adap/flower/pull/914)) - The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.dev/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.dev/docs/contributing-baselines.html). + The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/baselines/how-to-contribute-baselines.html). 
- **C++ client SDK (preview) and code example** ([#1111](https://github.com/adap/flower/pull/1111)) @@ -703,7 +703,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - New option to keep Ray running if Ray was already initialized in `start_simulation` ([#1177](https://github.com/adap/flower/pull/1177)) - Add support for custom `ClientManager` as a `start_simulation` parameter ([#1171](https://github.com/adap/flower/pull/1171)) - - New documentation for [implementing strategies](https://flower.dev/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) + - New documentation for [implementing strategies](https://flower.ai/docs/framework/how-to-implement-strategies.html) ([#1097](https://github.com/adap/flower/pull/1097), [#1175](https://github.com/adap/flower/pull/1175)) - New mobile-friendly documentation theme ([#1174](https://github.com/adap/flower/pull/1174)) - Limit version range for (optional) `ray` dependency to include only compatible releases (`>=1.9.2,<1.12.0`) ([#1205](https://github.com/adap/flower/pull/1205)) diff --git a/doc/source/ref-example-projects.rst b/doc/source/ref-example-projects.rst index b47bd8e48997..bade86dfaa54 100644 --- a/doc/source/ref-example-projects.rst +++ b/doc/source/ref-example-projects.rst @@ -23,8 +23,8 @@ The TensorFlow/Keras quickstart example shows CIFAR-10 image classification with MobileNetV2: - `Quickstart TensorFlow (Code) `_ -- `Quickstart TensorFlow (Tutorial) `_ -- `Quickstart TensorFlow (Blog Post) `_ +- :doc:`Quickstart TensorFlow (Tutorial) ` +- `Quickstart TensorFlow (Blog Post) `_ Quickstart PyTorch @@ -34,7 +34,7 @@ The PyTorch quickstart example shows CIFAR-10 image classification with a simple Convolutional Neural Network: - `Quickstart PyTorch (Code) `_ -- `Quickstart PyTorch (Tutorial) `_ +- :doc:`Quickstart PyTorch (Tutorial) ` PyTorch: From Centralized To Federated @@ -43,7 +43,7 @@ PyTorch: From Centralized To Federated This example shows how a regular PyTorch project can be federated using Flower: - `PyTorch: From Centralized To Federated (Code) `_ -- `PyTorch: From Centralized To Federated (Tutorial) `_ +- :doc:`PyTorch: From Centralized To Federated (Tutorial) ` Federated Learning on Raspberry Pi and Nvidia Jetson @@ -52,7 +52,7 @@ Federated Learning on Raspberry Pi and Nvidia Jetson This example shows how Flower can be used to build a federated learning system that run across Raspberry Pi and Nvidia Jetson: - `Federated Learning on Raspberry Pi and Nvidia Jetson (Code) `_ -- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ +- `Federated Learning on Raspberry Pi and Nvidia Jetson (Blog Post) `_ @@ -60,7 +60,7 @@ Legacy Examples (`flwr_example`) -------------------------------- .. warning:: - The useage examples in `flwr_example` are deprecated and will be removed in + The usage examples in `flwr_example` are deprecated and will be removed in the future. New examples are provided as standalone projects in `examples `_. @@ -114,7 +114,7 @@ For more details, see :code:`src/py/flwr_example/pytorch_cifar`. ImageNet-2012 Image Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -`ImageNet-2012 `_ is one of the major computer +`ImageNet-2012 `_ is one of the major computer vision datasets. The Flower ImageNet example uses PyTorch to train a ResNet-18 classifier in a federated learning setup with ten clients. 
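As a rough illustration of the pattern these example projects share, a PyTorch client for such a setup could look like the sketch below; it is not the code of the :code:`flwr_example` project, and the model and data specifics are placeholders:

.. code-block:: python

    from collections import OrderedDict

    import flwr as fl
    import torch
    import torchvision


    class ResNetClient(fl.client.NumPyClient):
        """Sketch of a NumPyClient wrapping a ResNet-18 (illustrative only)."""

        def __init__(self) -> None:
            self.model = torchvision.models.resnet18(num_classes=1000)

        def get_parameters(self, config):
            # Serialize all model tensors as NumPy arrays
            return [v.cpu().numpy() for v in self.model.state_dict().values()]

        def set_parameters(self, parameters):
            # Restore model tensors from the NumPy arrays sent by the server
            keys = self.model.state_dict().keys()
            state_dict = OrderedDict(
                {k: torch.tensor(v) for k, v in zip(keys, parameters)}
            )
            self.model.load_state_dict(state_dict, strict=True)

        # fit() and evaluate() would run the usual train/test loops and are
        # omitted here; see the linked example projects for complete code.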
diff --git a/doc/source/ref-faq.rst b/doc/source/ref-faq.rst index 13c44bc64b0e..26b7dca4a0a7 100644 --- a/doc/source/ref-faq.rst +++ b/doc/source/ref-faq.rst @@ -3,23 +3,23 @@ FAQ This page collects answers to commonly asked questions about Federated Learning with Flower. -.. dropdown:: :fa:`eye,mr-1` Can Flower run on Juptyter Notebooks / Google Colab? +.. dropdown:: :fa:`eye,mr-1` Can Flower run on Jupyter Notebooks / Google Colab? Yes, it can! Flower even comes with a few under-the-hood optimizations to make it work even better on Colab. Here's a quickstart example: - + * `Flower simulation PyTorch `_ * `Flower simulation TensorFlow/Keras `_ .. dropdown:: :fa:`eye,mr-1` How can I run Federated Learning on a Raspberry Pi? - Find the `blog post about federated learning on embedded device here `_ and the corresponding `GitHub code example `_. + Find the `blog post about federated learning on embedded device here `_ and the corresponding `GitHub code example `_. .. dropdown:: :fa:`eye,mr-1` Does Flower support federated learning on Android devices? - Yes, it does. Please take a look at our `blog post `_ or check out the code examples: + Yes, it does. Please take a look at our `blog post `_ or check out the code examples: - * `Android Kotlin example `_ - * `Android Java example `_ + * `Android Kotlin example `_ + * `Android Java example `_ .. dropdown:: :fa:`eye,mr-1` Can I combine federated learning with blockchain? @@ -27,6 +27,6 @@ This page collects answers to commonly asked questions about Federated Learning * `Flower meets Nevermined GitHub Repository `_. * `Flower meets Nevermined YouTube video `_. - * `Flower meets KOSMoS `_. + * `Flower meets KOSMoS `_. * `Flower meets Talan blog post `_ . * `Flower meets Talan GitHub Repository `_ . diff --git a/doc/source/ref-telemetry.md b/doc/source/ref-telemetry.md index 206e641d8b41..49efef5c8559 100644 --- a/doc/source/ref-telemetry.md +++ b/doc/source/ref-telemetry.md @@ -41,7 +41,7 @@ Flower telemetry collects the following metrics: **Source.** Flower telemetry tries to store a random source ID in `~/.flwr/source` the first time a telemetry event is generated. The source ID is important to identify whether an issue is recurring or whether an issue is triggered by multiple clusters running concurrently (which often happens in simulation). For example, if a device runs multiple workloads at the same time, and this results in an issue, then, in order to reproduce the issue, multiple workloads must be started at the same time. -You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.dev`. All events related to that source ID will then be permanently deleted. +You may delete the source ID at any time. If you wish for all events logged under a specific source ID to be deleted, you can send a deletion request mentioning the source ID to `telemetry@flower.ai`. All events related to that source ID will then be permanently deleted. We will not collect any personally identifiable information. If you think any of the metrics collected could be misused in any way, please [get in touch with us](#how-to-contact-us). We will update this page to reflect any changes to the metrics collected and publish changes in the changelog. @@ -63,4 +63,4 @@ FLWR_TELEMETRY_ENABLED=0 FLWR_TELEMETRY_LOGGING=1 python server.py # or client.p ## How to contact us -We want to hear from you. 
If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.dev/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.dev`). +We want to hear from you. If you have any feedback or ideas on how to improve the way we handle anonymous usage metrics, reach out to us via [Slack](https://flower.ai/join-slack/) (channel `#telemetry`) or email (`telemetry@flower.ai`). diff --git a/doc/source/tutorial-quickstart-huggingface.rst b/doc/source/tutorial-quickstart-huggingface.rst index 1e06120b452f..7d8128230901 100644 --- a/doc/source/tutorial-quickstart-huggingface.rst +++ b/doc/source/tutorial-quickstart-huggingface.rst @@ -9,8 +9,8 @@ Quickstart 🤗 Transformers Let's build a federated learning system using Hugging Face Transformers and Flower! -We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. -More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) +We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. +More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) for sequence classification over a dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative. @@ -32,8 +32,8 @@ Standard Hugging Face workflow Handling the data ^^^^^^^^^^^^^^^^^ -To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. -We then need to tokenize the data and create :code:`PyTorch` dataloaders, +To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. +We then need to tokenize the data and create :code:`PyTorch` dataloaders, this is all done in the :code:`load_data` function: .. code-block:: python @@ -80,8 +80,8 @@ this is all done in the :code:`load_data` function: Training and testing the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once we have a way of creating our trainloader and testloader, -we can take care of the training and testing. +Once we have a way of creating our trainloader and testloader, +we can take care of the training and testing. This is very similar to any :code:`PyTorch` training or testing loop: .. code-block:: python @@ -120,12 +120,12 @@ This is very similar to any :code:`PyTorch` training or testing loop: Creating the model itself ^^^^^^^^^^^^^^^^^^^^^^^^^ -To create the model itself, +To create the model itself, we will just load the pre-trained distillBERT model using Hugging Face’s :code:`AutoModelForSequenceClassification` : .. code-block:: python - from transformers import AutoModelForSequenceClassification + from transformers import AutoModelForSequenceClassification net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 @@ -138,8 +138,8 @@ Federating the example Creating the IMDBClient ^^^^^^^^^^^^^^^^^^^^^^^ -To federate our example to multiple clients, -we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). +To federate our example to multiple clients, +we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). This is very easy, as our model is a standard :code:`PyTorch` model: .. code-block:: python @@ -166,17 +166,17 @@ This is very easy, as our model is a standard :code:`PyTorch` model: return float(loss), len(testloader), {"accuracy": float(accuracy)} -The :code:`get_parameters` function lets the server get the client's parameters. 
-Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client.
-Finally, the :code:`fit` function trains the model locally for the client,
-and the :code:`evaluate` function tests the model locally and returns the relevant metrics.
+The :code:`get_parameters` function lets the server get the client's parameters.
+Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client.
+Finally, the :code:`fit` function trains the model locally for the client,
+and the :code:`evaluate` function tests the model locally and returns the relevant metrics.

Starting the server
^^^^^^^^^^^^^^^^^^^

-Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results.
-Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`,
-which will define the global weights as the average of all the clients' weights at each round)
+Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results.
+Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`,
+which will define the global weights as the average of all the clients' weights at each round)
and then using the :code:`flwr.server.start_server` function:

.. code-block:: python

@@ -186,7 +186,7 @@ and then using the :code:`flwr.server.start_server` function:
     losses = [num_examples * m["loss"] for num_examples, m in metrics]
     examples = [num_examples for num_examples, _ in metrics]
     return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
-
+
 # Define strategy
 strategy = fl.server.strategy.FedAvg(
     fraction_fit=1.0,
@@ -202,7 +202,7 @@ and then using the :code:`flwr.server.start_server` function:
 )

-The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst
+The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst
 the clients (basically this allows us to display a nice average accuracy and loss for every round).

Putting everything together
@@ -213,18 +213,17 @@ We can now start client instances using:

.. code-block:: python

 fl.client.start_client(
-    server_address="127.0.0.1:8080",
+    server_address="127.0.0.1:8080",
     client=IMDBClient().to_client()
 )

And they will be able to connect to the server and start the federated training.

-If you want to check out everything put together,
-you should check out the full code example:
-[https://github.com/adap/flower/tree/main/examples/quickstart-huggingface](https://github.com/adap/flower/tree/main/examples/quickstart-huggingface).
+If you want to check out everything put together,
+you should check out the `full code example <https://github.com/adap/flower/tree/main/examples/quickstart-huggingface>`_.

-Of course, this is a very basic example, and a lot can be added or modified,
+Of course, this is a very basic example, and a lot can be added or modified;
it was just to showcase how simply we could federate a Hugging Face workflow using Flower.

Note that in this example we used :code:`PyTorch`, but we could have very well used :code:`TensorFlow`.
diff --git a/doc/source/tutorial-quickstart-ios.rst b/doc/source/tutorial-quickstart-ios.rst
index 7c8007baaa75..e4315ce569fb 100644
--- a/doc/source/tutorial-quickstart-ios.rst
+++ b/doc/source/tutorial-quickstart-ios.rst
@@ -7,14 +7,14 @@ Quickstart iOS

..
meta:: :description: Read this Federated Learning quickstart tutorial for creating an iOS app using Flower to train a neural network on MNIST. -In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. +In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. -First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. For the Flower client implementation in iOS, it is recommended to use Xcode as our IDE. -Our example consists of one Python *server* and two iPhone *clients* that all have the same model. +Our example consists of one Python *server* and two iPhone *clients* that all have the same model. -*Clients* are responsible for generating individual weight updates for the model based on their local datasets. +*Clients* are responsible for generating individual weight updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -44,10 +44,10 @@ For simplicity reasons we will use the complete Flower client with CoreML, that public func getParameters() -> GetParametersRes { let parameters = parameters.weightsToParameters() let status = Status(code: .ok, message: String()) - + return GetParametersRes(parameters: parameters, status: status) } - + /// Calls the routine to fit the local model /// /// - Returns: The result from the local training, e.g., updated parameters @@ -55,17 +55,17 @@ For simplicity reasons we will use the complete Flower client with CoreML, that let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .train) let parameters = parameters.weightsToParameters() - + return FitRes(parameters: parameters, numExamples: result.numSamples, status: status) } - + /// Calls the routine to evaluate the local model /// /// - Returns: The result from the evaluation, e.g., loss public func evaluate(ins: EvaluateIns) -> EvaluateRes { let status = Status(code: .ok, message: String()) let result = runMLTask(configuration: parameters.parametersToWeights(parameters: ins.parameters), task: .test) - + return EvaluateRes(loss: Float(result.loss), numExamples: result.numSamples, status: status) } @@ -88,18 +88,18 @@ For the MNIST dataset, we need to preprocess it into :code:`MLBatchProvider` obj // prepare train dataset let trainBatchProvider = DataLoader.trainBatchProvider() { _ in } - + // prepare test dataset let testBatchProvider = DataLoader.testBatchProvider() { _ in } - + // load them together - let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, + let dataLoader = MLDataLoader(trainBatchProvider: trainBatchProvider, testBatchProvider: testBatchProvider) Since CoreML does not allow the model parameters to be seen before training, and accessing the model parameters during or after the training can only be done by specifying the layer name, -we need to know this informations beforehand, through looking at the model specification, which are written as proto files. 
The implementation can be seen in :code:`MLModelInspect`.
+we need to know this information beforehand, by looking at the model specification, which is written as proto files. The implementation can be seen in :code:`MLModelInspect`.

-After we have all of the necessary informations, let's create our Flower client.
+After we have all of the necessary information, let's create our Flower client.

.. code-block:: swift

@@ -122,7 +122,7 @@ Then start the Flower gRPC client and start communicating to the server by passi
        self.flwrGRPC.startFlwrGRPC(client: self.mlFlwrClient)

That's it for the client. We only have to implement :code:`Client` or call the provided
-:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attribute :code:`hostname` and :code:`port` tells the client which server to connect to.
+:code:`MLFlwrClient` and call :code:`startFlwrGRPC()`. The attributes :code:`hostname` and :code:`port` tell the client which server to connect to.
This can be done by entering the hostname and port in the application before clicking the start button to start the federated learning process.

Flower Server
diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst
index ff8d4b2087dd..fe582f793280 100644
--- a/doc/source/tutorial-quickstart-mxnet.rst
+++ b/doc/source/tutorial-quickstart-mxnet.rst
@@ -4,18 +4,18 @@ Quickstart MXNet
================

-.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongise Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower.
+.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongside Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower.

.. meta::
   :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST.

-In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet.
+In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet.

-It is recommended to create a virtual environment and run everything within this `virtualenv `_.
+It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `.

-Our example consists of one *server* and two *clients* all having the same model.
+Our example consists of one *server* and two *clients* all having the same model.

-*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets.
+*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets.
These updates are then sent to the *server* which will aggregate them to produce an updated global model.
Finally, the *server* sends this improved version of the model back to each *client*.
A complete cycle of parameter updates is called a *round*.

@@ -35,12 +35,12 @@ Since we want to use MXNet, let's go ahead and install it:

Flower Client
-------------

-Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet´s `Hand-written Digit Recognition tutorial `_.
+Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet's `Hand-written Digit Recognition tutorial `_.

In a file called :code:`client.py`, import Flower and MXNet related packages:

.. code-block:: python

-
+
    import flwr as fl

    import numpy as np

@@ -58,7 +58,7 @@ In addition, define the device allocation in MXNet with:

    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

-We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data.
+We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility :code:`mx.test_utils.get_mnist()` downloads the training and test data.

.. code-block:: python

@@ -72,7 +72,7 @@ We use MXNet to load MNIST, a popular image classification dataset of handwritte
        val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
        return train_data, val_data

-Define the training and loss with MXNet. We train the model by looping over the dataset, measure the corresponding loss, and optimize it.
+Define the training and loss with MXNet. We train the model by looping over the dataset, measuring the corresponding loss, and optimizing it.

.. code-block:: python

@@ -110,7 +110,7 @@ Define the training and loss with MXNet. We train the model by looping over the
        return trainings_metric, num_examples

-Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy on the test set.
+Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy.

.. code-block:: python

@@ -155,7 +155,7 @@ Our Flower clients will use a simple :code:`Sequential` model:
        init = nd.random.uniform(shape=(2, 784))
        model(init)

-After loading the dataset with :code:`load_data()` we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client.
+After loading the dataset with :code:`load_data()`, we perform one forward propagation to initialize the model and model parameters with :code:`model(init)`. Next, we implement a Flower client.

The Flower server interacts with clients through an interface called :code:`Client`.
When the server selects a particular client for training, it

@@ -207,7 +207,7 @@ They can be implemented in the following way:
        [accuracy, loss], num_examples = test(model, val_data)
        print("Evaluation accuracy & loss", accuracy, loss)
        return float(loss[1]), val_data.batch_size, {"accuracy": float(accuracy[1])}
-
+
We can now create an instance of our class :code:`MNISTClient` and add one line to actually run this client:

diff --git a/doc/source/tutorial-quickstart-pytorch.rst b/doc/source/tutorial-quickstart-pytorch.rst
index f15a4a93114e..895590808a2b 100644
--- a/doc/source/tutorial-quickstart-pytorch.rst
+++ b/doc/source/tutorial-quickstart-pytorch.rst
@@ -10,13 +10,13 @@ Quickstart PyTorch

.. youtube:: jOmmuzMIQ4c
   :width: 100%

-In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch.
+In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch.
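+As a preview of where this quickstart ends up, the sketch below shows the client pattern the tutorial builds toward. It is only a sketch: :code:`Net`, :code:`DEVICE`, :code:`train`, :code:`test`, and :code:`load_data` are the model and helper functions defined step by step in the sections that follow, and the name :code:`CifarClient` is illustrative.
+
+.. code-block:: python
+
+    from collections import OrderedDict
+
+    import torch
+    import flwr as fl
+
+    net = Net().to(DEVICE)
+    trainloader, testloader, num_examples = load_data()
+
+    class CifarClient(fl.client.NumPyClient):
+        def get_parameters(self, config):
+            # Serialize the local model as a list of NumPy arrays
+            return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+        def set_parameters(self, parameters):
+            # Load the aggregated global parameters into the local model
+            params_dict = zip(net.state_dict().keys(), parameters)
+            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+            net.load_state_dict(state_dict, strict=True)
+
+        def fit(self, parameters, config):
+            # Train on the local partition and report the updated weights
+            self.set_parameters(parameters)
+            train(net, trainloader, epochs=1)
+            return self.get_parameters(config={}), num_examples["trainset"], {}
+
+        def evaluate(self, parameters, config):
+            # Evaluate the received global parameters on the local test set
+            self.set_parameters(parameters)
+            loss, accuracy = test(net, testloader)
+            return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}
+
+A client built this way can then be started with, for example, :code:`fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())`.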
-First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_.
+First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `.

-Our example consists of one *server* and two *clients* all having the same model.
+Our example consists of one *server* and two *clients* all having the same model.

-*Clients* are responsible for generating individual weight-updates for the model based on their local datasets.
+*Clients* are responsible for generating individual weight-updates for the model based on their local datasets.
These updates are then sent to the *server* which will aggregate them to produce a better model.
Finally, the *server* sends this improved version of the model back to each *client*.
A complete cycle of weight updates is called a *round*.

@@ -26,7 +26,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n

    $ pip install flwr

-Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library:
+Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the **torchvision** library:

.. code-block:: shell

@@ -36,12 +36,12 @@ Since we want to use PyTorch to solve a computer vision task, let's go ahead and

Flower Client
-------------

-Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_.
+Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch `_.

In a file called :code:`client.py`, import Flower and PyTorch related packages:

.. code-block:: python

-
+
    from collections import OrderedDict

    import torch

@@ -59,7 +59,7 @@ In addition, we define the device allocation in PyTorch with:

    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized.
+We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The PyTorch :code:`DataLoader()` downloads the training and test data that are then normalized.

.. code-block:: python

@@ -75,7 +75,7 @@ We use PyTorch to load CIFAR10, a popular colored image classification dataset f
        num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
        return trainloader, testloader, num_examples

-Define the loss and optimizer with PyTorch. The training of the dataset is done by looping over the dataset, measure the corresponding loss and optimize it.
+Define the loss and optimizer with PyTorch. Training is done by looping over the dataset, measuring the corresponding loss, and optimizing it.

.. code-block:: python

@@ -91,7 +91,7 @@ Define the loss and optimizer with PyTorch. The training of the dataset is done
        loss.backward()
        optimizer.step()

-Define then the validation of the machine learning network. We loop over the test set and measure the loss and accuracy of the test set.
+Then define the validation of the machine learning network. We loop over the test set and measure its loss and accuracy.

.. code-block:: python
@@ -139,7 +139,7 @@ The Flower clients will use a simple CNN adapted from 'PyTorch: A 60 Minute Blit
    net = Net().to(DEVICE)
    trainloader, testloader, num_examples = load_data()

-After loading the data set with :code:`load_data()` we define the Flower interface.
+After loading the dataset with :code:`load_data()`, we define the Flower interface.

The Flower server interacts with clients through an interface called :code:`Client`.
When the server selects a particular client for training, it
diff --git a/doc/source/tutorial-quickstart-scikitlearn.rst b/doc/source/tutorial-quickstart-scikitlearn.rst
index 4921f63bab2c..d1d47dc37f19 100644
--- a/doc/source/tutorial-quickstart-scikitlearn.rst
+++ b/doc/source/tutorial-quickstart-scikitlearn.rst
@@ -7,13 +7,13 @@ Quickstart scikit-learn

.. meta::
   :description: Check out this Federated Learning quickstart tutorial for using Flower with scikit-learn to train a logistic regression model.

-In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn.
+In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn.

-It is recommended to create a virtual environment and run everything within this `virtualenv `_.
+It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `.

-Our example consists of one *server* and two *clients* all having the same model.
+Our example consists of one *server* and two *clients* all having the same model.

-*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets.
+*Clients* are responsible for generating individual model parameter updates for the model based on their local datasets.
These updates are then sent to the *server* which will aggregate them to produce an updated global model.
Finally, the *server* sends this improved version of the model back to each *client*.
A complete cycle of parameter updates is called a *round*.

@@ -23,7 +23,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n

    $ pip install flwr

-Since we want to use scikt-learn, let's go ahead and install it:
+Since we want to use scikit-learn, let's go ahead and install it:

.. code-block:: shell

@@ -43,7 +43,7 @@ Now that we have all our dependencies installed, let's run a simple distributed
However, before setting up the client and server, we will define all functionalities that we need for our federated learning setup within :code:`utils.py`. The :code:`utils.py` contains different functions defining all the machine learning basics:

* :code:`get_model_parameters()`
-    * Returns the paramters of a :code:`sklearn` LogisticRegression model
+    * Returns the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_model_params()`
    * Sets the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_initial_params()`

@@ -59,7 +59,7 @@ Please check out :code:`utils.py` `here `_
-We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`.
+We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`.
.. code-block:: python

diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst
index 3a7b356c4d2a..751024db14e4 100644
--- a/doc/source/tutorial-quickstart-xgboost.rst
+++ b/doc/source/tutorial-quickstart-xgboost.rst
@@ -36,7 +36,7 @@ and then we dive into a more complex example (`full code xgboost-comprehensive <
Environment Setup
--------------------

-First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_.
+First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `.

We first need to install Flower and Flower Datasets. You can do this by running:

@@ -596,7 +596,7 @@ Comprehensive Federated XGBoost
Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings.
In the xgboost-comprehensive example (`full code `_), we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation.
-We also support `Flower simulation `_ making it easy to simulate large client cohorts in a resource-aware manner.
+We also support :doc:`Flower simulation `, making it easy to simulate large client cohorts in a resource-aware manner.
Let's take a look!

Cyclic training
diff --git a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb
index 5b2236468909..c5fc777e7f26 100644
--- a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb
+++ b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb
@@ -7,11 +7,11 @@
   "source": [
    "# Build a strategy from scratch\n",
    "\n",
-    "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n",
+    "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n",
    "\n",
-    "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n",
+    "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n",
    "\n",
-    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel!
And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's build a new `Strategy` from scratch!" ] @@ -489,11 +489,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 4](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." + "The [Flower Federated Learning Tutorial - Part 4](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." ] } ], diff --git a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb index 0ff67de6f51d..ce09c6cc46c1 100644 --- a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb +++ b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Customize the client\n", "\n", - "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", + "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n", "\n", "In this notebook, we revisit `NumPyClient` and introduce a new baseclass for building clients, simply named `Client`. In previous parts of this tutorial, we've based our client on `NumPyClient`, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. 
With `Client`, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things that we didn't have to do before.\n",
    "\n",
-    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
+    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's go deeper and see what it takes to move from `NumPyClient` to `Client`!"
   ]
@@ -869,16 +869,16 @@
   "source": [
    "## Next steps\n",
    "\n",
-    "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n",
+    "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n",
    "\n",
    "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
    "\n",
    "This is the final part of the Flower tutorial (for now!), congratulations! You're now well equipped to understand the rest of the documentation. There are many topics we didn't cover in the tutorial; we recommend the following resources:\n",
    "\n",
-    "- [Read Flower Docs](https://flower.dev/docs/)\n",
+    "- [Read Flower Docs](https://flower.ai/docs/)\n",
    "- [Check out Flower Code Examples](https://github.com/adap/flower/tree/main/examples)\n",
-    "- [Use Flower Baselines for your research](https://flower.dev/docs/baselines/)\n",
-    "- [Watch Flower Summit 2023 videos](https://flower.dev/conf/flower-summit-2023/)\n"
+    "- [Use Flower Baselines for your research](https://flower.ai/docs/baselines/)\n",
+    "- [Watch Flower Summit 2023 videos](https://flower.ai/conf/flower-summit-2023/)\n"
   ]
  }
 ],
diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb
index 704ed520bf3e..205531c54ee6 100644
--- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb
+++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb
@@ -9,9 +9,9 @@
    "\n",
    "Welcome to the Flower federated learning tutorial!\n",
    "\n",
-    "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.dev/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n",
+    "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.ai/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n",
    "\n",
-    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
+    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's get started!"
   ]
@@ -605,11 +605,11 @@
   "source": [
    "## Next steps\n",
    "\n",
-    "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n",
+    "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n",
    "\n",
    "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
    "\n",
-    "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n"
+    "The [Flower Federated Learning Tutorial - Part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n"
   ]
  }
 ],
diff --git a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb
index 06f53cd8e1b1..e20a8d83f674 100644
--- a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb
+++ b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb
@@ -7,11 +7,11 @@
   "source": [
    "# Use a federated learning strategy\n",
    "\n",
-    "Welcome to the next part of the federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n",
+    "Welcome to the next part of the federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)).\n",
    "\n",
-    "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n",
+    "In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n",
    "\n",
-    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
+    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's move beyond FedAvg with Flower strategies!"
] @@ -614,11 +614,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." + "The [Flower Federated Learning Tutorial - Part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." ] } ], diff --git a/doc/source/tutorial-series-what-is-federated-learning.ipynb b/doc/source/tutorial-series-what-is-federated-learning.ipynb index 3f7e383b9fbc..d77182838f21 100755 --- a/doc/source/tutorial-series-what-is-federated-learning.ipynb +++ b/doc/source/tutorial-series-what-is-federated-learning.ipynb @@ -13,7 +13,7 @@ "\n", "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get started!" ] @@ -217,11 +217,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." + "The [Flower Federated Learning Tutorial - Part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." 
] } ], diff --git a/e2e/bare-https/driver.py b/e2e/bare-https/driver.py index dd7c9eab7248..f7bfeb613f6a 100644 --- a/e2e/bare-https/driver.py +++ b/e2e/bare-https/driver.py @@ -3,7 +3,7 @@ # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="127.0.0.1:9091", config=fl.server.ServerConfig(num_rounds=3), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare-https/pyproject.toml b/e2e/bare-https/pyproject.toml index 9489a43195f9..3afb7b57a084 100644 --- a/e2e/bare-https/pyproject.toml +++ b/e2e/bare-https/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "bare_https_test" version = "0.1.0" description = "HTTPS-enabled bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/bare/client.py b/e2e/bare/client.py index a9425b39778a..c291fb0963e4 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -3,7 +3,7 @@ import flwr as fl import numpy as np -from flwr.common.configsrecord import ConfigsRecord +from flwr.common import ConfigsRecord SUBSET_SIZE = 1000 STATE_VAR = 'timestamp' @@ -21,14 +21,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.context.state.configs.keys(): - value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore value += f",{t_stamp}" - self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.context.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): model_params = parameters diff --git a/e2e/bare/driver.py b/e2e/bare/driver.py index d428fe757aa9..defc2ad56213 100644 --- a/e2e/bare/driver.py +++ b/e2e/bare/driver.py @@ -2,7 +2,7 @@ # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/bare/pyproject.toml b/e2e/bare/pyproject.toml index cde8728f5c34..45ce7ea333af 100644 --- a/e2e/bare/pyproject.toml +++ b/e2e/bare/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "bare_test" version = "0.1.0" description = "Bare Federated Learning test with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/fastai/driver.py b/e2e/fastai/driver.py index b7b1c41ff5a3..cc452ea523ca 100644 --- a/e2e/fastai/driver.py +++ b/e2e/fastai/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/fastai/pyproject.toml b/e2e/fastai/pyproject.toml index 66fcb393d988..feed31f6d202 100644 --- a/e2e/fastai/pyproject.toml +++ b/e2e/fastai/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-fastai" version = "0.1.0" description = "Fastai Federated Learning E2E test with Flower" -authors = ["The Flower Authors "] 
+authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.10" diff --git a/e2e/jax/driver.py b/e2e/jax/driver.py index b7b1c41ff5a3..cc452ea523ca 100644 --- a/e2e/jax/driver.py +++ b/e2e/jax/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/jax/pyproject.toml b/e2e/jax/pyproject.toml index 3db32ea855eb..9a4af5dee59a 100644 --- a/e2e/jax/pyproject.toml +++ b/e2e/jax/pyproject.toml @@ -2,7 +2,7 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/opacus/driver.py b/e2e/opacus/driver.py index 5bc40800c33c..75acd9ccea24 100644 --- a/e2e/opacus/driver.py +++ b/e2e/opacus/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/opacus/pyproject.toml b/e2e/opacus/pyproject.toml index 1aee7c6ec6d9..ab4a727cc00b 100644 --- a/e2e/opacus/pyproject.toml +++ b/e2e/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "opacus_e2e" version = "0.1.0" description = "Opacus E2E testing" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pandas/driver.py b/e2e/pandas/driver.py index 78120fc946f2..f5dc74c9f3f8 100644 --- a/e2e/pandas/driver.py +++ b/e2e/pandas/driver.py @@ -3,7 +3,7 @@ from strategy import FedAnalytics # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=1), strategy=FedAnalytics(), diff --git a/e2e/pandas/pyproject.toml b/e2e/pandas/pyproject.toml index b90037a31068..416dfeec3460 100644 --- a/e2e/pandas/pyproject.toml +++ b/e2e/pandas/pyproject.toml @@ -7,7 +7,7 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pandas/simulation.py b/e2e/pandas/simulation.py index 91af84062712..b548b5ebb760 100644 --- a/e2e/pandas/simulation.py +++ b/e2e/pandas/simulation.py @@ -1,12 +1,8 @@ import flwr as fl -from client import FlowerClient +from client import client_fn from strategy import FedAnalytics -def client_fn(cid): - _ = cid - return FlowerClient() - hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, diff --git a/e2e/pytorch-lightning/driver.py b/e2e/pytorch-lightning/driver.py index b7b1c41ff5a3..cc452ea523ca 100644 --- a/e2e/pytorch-lightning/driver.py +++ b/e2e/pytorch-lightning/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/pytorch-lightning/pyproject.toml b/e2e/pytorch-lightning/pyproject.toml index 951349c03a04..88cddddf500f 100644 --- a/e2e/pytorch-lightning/pyproject.toml +++ b/e2e/pytorch-lightning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" 
description = "Federated Learning E2E test with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index 0f1b8e159f7d..1fd07763148e 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -11,7 +11,7 @@ from tqdm import tqdm import flwr as fl -from flwr.common.configsrecord import ConfigsRecord +from flwr.common import ConfigsRecord # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -97,14 +97,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.context.state.configs.keys(): - value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs_records.keys(): + value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore value += f",{t_stamp}" - self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value}) def _retrieve_timestamp_from_state(self): - return self.context.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.configs_records[STATE_VAR][STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) diff --git a/e2e/pytorch/driver.py b/e2e/pytorch/driver.py index 9f9b076ee75b..2ea4de69a62b 100644 --- a/e2e/pytorch/driver.py +++ b/e2e/pytorch/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/pytorch/pyproject.toml b/e2e/pytorch/pyproject.toml index 4aa326330c62..e538f1437df6 100644 --- a/e2e/pytorch/pyproject.toml +++ b/e2e/pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/e2e/scikit-learn/driver.py b/e2e/scikit-learn/driver.py index e7ce124e5ead..29051d02c6b6 100644 --- a/e2e/scikit-learn/driver.py +++ b/e2e/scikit-learn/driver.py @@ -36,7 +36,7 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config): evaluate_fn=get_evaluate_fn(model), on_fit_config_fn=fit_round, ) - hist = fl.server.driver.start_driver( + hist = fl.server.start_driver( server_address="0.0.0.0:9091", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3), diff --git a/e2e/scikit-learn/pyproject.toml b/e2e/scikit-learn/pyproject.toml index 372ae1218bbe..50c07d31add7 100644 --- a/e2e/scikit-learn/pyproject.toml +++ b/e2e/scikit-learn/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/e2e/strategies/pyproject.toml b/e2e/strategies/pyproject.toml index 
edfb16de59a6..5cc74b20fa24 100644 --- a/e2e/strategies/pyproject.toml +++ b/e2e/strategies/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart_tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/e2e/tabnet/driver.py b/e2e/tabnet/driver.py index b7b1c41ff5a3..cc452ea523ca 100644 --- a/e2e/tabnet/driver.py +++ b/e2e/tabnet/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml index 7b47ffeb1470..b1abf382a24a 100644 --- a/e2e/tabnet/pyproject.toml +++ b/e2e/tabnet/pyproject.toml @@ -6,13 +6,13 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.9.2" tensorflow-io-gcs-filesystem = "<0.35.0" tabnet = "0.1.6" diff --git a/e2e/tensorflow/driver.py b/e2e/tensorflow/driver.py index 9f9b076ee75b..2ea4de69a62b 100644 --- a/e2e/tensorflow/driver.py +++ b/e2e/tensorflow/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/tensorflow/pyproject.toml b/e2e/tensorflow/pyproject.toml index 467e69a026b3..a7dbfe2305db 100644 --- a/e2e/tensorflow/pyproject.toml +++ b/e2e/tensorflow/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 9101105b2618..c1ba85b95879 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (PyTorch) -This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with PyTorch. 
This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set). diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index 89fd5a32a89e..b846a6054cc8 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -7,8 +7,8 @@ name = "advanced-pytorch" version = "0.1.0" description = "Advanced Flower/PyTorch Example" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 186f079010dc..4a0f6918cdd6 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -9,10 +9,10 @@ warnings.filterwarnings("ignore") -def load_partition(node_id, toy: bool = False): +def load_partition(partition_id, toy: bool = False): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) partition_train_test = partition_train_test.with_transform(apply_transforms) diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index b21c0d2545ca..59866fd99a06 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (TensorFlow/Keras) -This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. 
This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 1/10 of the train datasets and 80% is training examples and 20% as test examples (note that by default only a small subset of this data is used when running the `run.sh` script) diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 2f16d8a15584..02bd923129a4 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "advanced-tensorflow" version = "0.1.0" description = "Advanced Flower/TensorFlow Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/android/pyproject.toml b/examples/android/pyproject.toml index 2b9cd8c978a7..0371f7208292 100644 --- a/examples/android/pyproject.toml +++ b/examples/android/pyproject.toml @@ -6,10 +6,10 @@ build-backend = "poetry.masonry.api" name = "android_flwr_tensorflow" version = "0.1.0" description = "Android Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md new file mode 100644 index 000000000000..de1b6fdbb819 --- /dev/null +++ b/examples/app-pytorch/README.md @@ -0,0 +1,67 @@ +# Flower App (PyTorch) 🧪 + +> 🧪 = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to start a long-running Flower server (SuperLink) and then run a Flower App (consisting of a `ClientApp` and a `ServerApp`). + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. 
+├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── server_workflow.py # <-- contains `ServerApp` with workflow +├── server_custom.py # <-- contains `ServerApp` with custom main function +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +## Start the long-running Flower client (SuperNode) + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +## Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` + +Or, to try the workflow example, run: + +```bash +flower-server-app server_workflow:app --insecure +``` + +Or, to try the custom server function example, run: + +```bash +flower-server-app server_custom:app --insecure +``` diff --git a/examples/mt-pytorch/client.py b/examples/app-pytorch/client.py similarity index 78% rename from examples/mt-pytorch/client.py rename to examples/app-pytorch/client.py index 1f2db323ac34..8095a2d7aa93 100644 --- a/examples/mt-pytorch/client.py +++ b/examples/app-pytorch/client.py @@ -38,16 +38,7 @@ def client_fn(cid: str): return FlowerClient().to_client() -# To run this: `flower-client client:app` +# Run via `flower-client-app client:app` app = fl.client.ClientApp( client_fn=client_fn, ) - - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST - client_fn=client_fn, - transport="grpc-rere", # "rest" for REST - ) diff --git a/examples/mt-pytorch/pyproject.toml b/examples/app-pytorch/pyproject.toml similarity index 67% rename from examples/mt-pytorch/pyproject.toml rename to examples/app-pytorch/pyproject.toml index 4978035495ea..bc22d70e9075 100644 --- a/examples/mt-pytorch/pyproject.toml +++ b/examples/app-pytorch/pyproject.toml @@ -3,14 +3,14 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "mt-pytorch" +name = "app-pytorch" version = "0.1.0" description = "Multi-Tenant Federated Learning with Flower and PyTorch" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = {version = ">=1.0,<2.0", extras = ["rest", "simulation"]} +flwr = { path = "../../", develop = true, extras = ["simulation"] } torch = "1.13.1" torchvision = "0.14.1" tqdm = "4.65.0" diff --git a/examples/mt-pytorch/requirements.txt b/examples/app-pytorch/requirements.txt similarity index 100% rename from examples/mt-pytorch/requirements.txt rename to examples/app-pytorch/requirements.txt diff --git a/examples/mt-pytorch/start_driver.py b/examples/app-pytorch/server.py similarity index 85% rename from examples/mt-pytorch/start_driver.py rename to examples/app-pytorch/server.py index 307f4ebd1a3b..fbf3f24a133d 100644 --- a/examples/mt-pytorch/start_driver.py +++ b/examples/app-pytorch/server.py @@ -33,10 +33,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: fit_metrics_aggregation_fn=weighted_average, ) -if __name__ == "__main__": - # 
Start Flower server - fl.server.driver.start_driver( - server_address="0.0.0.0:9091", - config=fl.server.ServerConfig(num_rounds=3), - strategy=strategy, - ) + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp( + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py new file mode 100644 index 000000000000..1f0cb0d26d93 --- /dev/null +++ b/examples/app-pytorch/server_custom.py @@ -0,0 +1,154 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.server import Driver +from flwr.common import Context + +from flwr.common import ( + ServerMessage, + FitIns, + ndarrays_to_parameters, + serde, + parameters_to_ndarrays, + ClientMessage, + NDArrays, + Code, +) +from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 +from flwr.server.strategy.aggregate import aggregate +from flwr.common import Metrics +from flwr.server import History +from flwr.common import serde +from task import Net, get_parameters, set_parameters +from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres +from flwr.common import Message + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """.""" + print("RUNNING!!!!!") + + anonymous_client_nodes = False + num_client_nodes_per_round = 2 + sleep_time = 1 + num_rounds = 3 + parameters = ndarrays_to_parameters(get_parameters(net=Net())) + + history = History() + for server_round in range(num_rounds): + print(f"Commencing server round {server_round + 1}") + + # List of sampled node IDs in this round + sampled_nodes: List[int] = [] + + # The Driver API might not immediately return enough client node IDs, so we + # loop and wait until enough client nodes are available. 
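+        # A few details worth noting about the polling loop below:
+        # `driver.get_node_ids()` returns the IDs of all SuperNodes currently
+        # connected to the SuperLink, and `random.sample` draws node IDs
+        # without replacement, so each sampled node receives at most one fit
+        # task in a given round. The 3-second sleep is just a polling
+        # interval; any small delay would do.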
+ while True: + all_node_ids = driver.get_node_ids() + + print(f"Got {len(all_node_ids)} client nodes: {all_node_ids}") + if len(all_node_ids) >= num_client_nodes_per_round: + # Sample client nodes + sampled_nodes = random.sample(all_node_ids, num_client_nodes_per_round) + break + time.sleep(3) + + # Log sampled node IDs + print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") + + # Schedule a task for all sampled nodes + fit_ins: FitIns = FitIns(parameters=parameters, config={}) + recordset = fitins_to_recordset(fitins=fit_ins, keep_input=True) + + messages = [] + for node_id in sampled_nodes: + message = driver.create_message( + content=recordset, + message_type="fit", + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + # Collect correct results + all_fitres = [ + recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + ] + print(f"Received {len(all_fitres)} results") + + weights_results: List[Tuple[NDArrays, int]] = [] + metrics_results: List[Tuple[int, Dict]] = [] + for fitres in all_fitres: + print(f"num_examples: {fitres.num_examples}, status: {fitres.status.code}") + + # Aggregate only if the status is OK + if fitres.status.code != Code.OK: + continue + weights_results.append( + (parameters_to_ndarrays(fitres.parameters), fitres.num_examples) + ) + metrics_results.append((fitres.num_examples, fitres.metrics)) + + # Aggregate parameters (FedAvg) + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated + + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + + # Slow down the start of the next round + time.sleep(sleep_time) + + print("app_fit: losses_distributed %s", str(history.losses_distributed)) + print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit)) + print("app_fit: metrics_distributed %s", str(history.metrics_distributed)) + print("app_fit: losses_centralized %s", str(history.losses_centralized)) + print("app_fit: metrics_centralized %s", str(history.metrics_centralized)) diff --git a/examples/app-pytorch/server_workflow.py b/examples/app-pytorch/server_workflow.py new file mode 100644 index 000000000000..920e266c99e9 --- /dev/null +++ b/examples/app-pytorch/server_workflow.py @@ -0,0 +1,55 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Context, Metrics +from flwr.server import Driver, LegacyContext + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = 
[num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }
+
+
+# Define strategy
+strategy = fl.server.strategy.FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+)
+
+
+# Run via `flower-server-app server_workflow:app`
+app = fl.server.ServerApp()
+
+
+@app.main()
+def main(driver: Driver, context: Context) -> None:
+    # Construct the LegacyContext
+    context = LegacyContext(
+        state=context.state,
+        config=fl.server.ServerConfig(num_rounds=3),
+        strategy=strategy,
+    )
+
+    # Create the workflow
+    workflow = fl.server.workflow.DefaultWorkflow()
+
+    # Execute
+    workflow(driver, context)
diff --git a/examples/mt-pytorch/task.py b/examples/app-pytorch/task.py
similarity index 100%
rename from examples/mt-pytorch/task.py
rename to examples/app-pytorch/task.py
diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md
index debcd7919839..317fb6336106 100644
--- a/examples/custom-metrics/README.md
+++ b/examples/custom-metrics/README.md
@@ -9,7 +9,7 @@ The main takeaways of this implementation are:
 - the use of the `output_dict` on the client side - inside the `evaluate` method in `client.py`
 - the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` in `server.py`

-This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.dev/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve the CIFAR-10.
+This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.ai/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve the CIFAR-10.

 Since CIFAR-10 poses a multi-class classification problem, some changes to how the metrics are calculated, using `average='micro'` and `np.argmax`, are required. For binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on the reconstruction error could be implemented on the client side.
@@ -91,16 +91,16 @@ chmod +x run.sh
 ./run.sh
 ```

-You will see that Keras is starting a federated training. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
+You will see that Keras is starting a federated training. Have a look at the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
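To make the metrics flow described above concrete, here is a minimal sketch (illustrative, not the example's exact `client.py`/`server.py`; the helper names are made up) of computing micro-averaged metrics on the client with `np.argmax` and aggregating them on the server via `evaluate_metrics_aggregation_fn`:

```python
import numpy as np
import flwr as fl
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def compute_metrics(y_true, y_prob):
    """Client side: micro-averaged metrics for the multi-class case."""
    y_pred = np.argmax(y_prob, axis=1)  # pick the most probable class per sample
    return {
        "acc": accuracy_score(y_true, y_pred),
        "rec": recall_score(y_true, y_pred, average="micro"),
        "prec": precision_score(y_true, y_pred, average="micro"),
        "f1": f1_score(y_true, y_pred, average="micro"),
    }


def weighted_average(metrics):
    """Server side: weight each client's metrics by its number of examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    keys = metrics[0][1].keys()
    return {
        key: sum(num_examples * m[key] for num_examples, m in metrics) / total_examples
        for key in keys
    }


strategy = fl.server.strategy.FedAvg(
    evaluate_metrics_aggregation_fn=weighted_average,
)
```

The aggregated keys (`acc`, `rec`, `prec`, `f1`) are exactly the ones that show up in the distributed-metrics output below.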
Running `run.sh` will result in the following output (after 3 rounds): ```shell INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { - 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], - 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], - 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], - 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], + 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], 'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] } ``` diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml index 8a2da6562018..51c29e213d81 100644 --- a/examples/custom-metrics/pyproject.toml +++ b/examples/custom-metrics/pyproject.toml @@ -7,8 +7,8 @@ name = "custom-metrics" version = "0.1.0" description = "Federated Learning with Flower and Custom Metrics" authors = [ - "The Flower Authors ", - "Gustavo Bertoli " + "The Flower Authors ", + "Gustavo Bertoli ", ] [tool.poetry.dependencies] @@ -16,4 +16,4 @@ python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { version = "*", extras = ["vision"] } scikit-learn = "^1.2.2" -tensorflow = "==2.12.0" \ No newline at end of file +tensorflow = "==2.12.0" diff --git a/examples/doc/source/_templates/base.html b/examples/doc/source/_templates/base.html index e4fe80720b74..08030fb08c15 100644 --- a/examples/doc/source/_templates/base.html +++ b/examples/doc/source/_templates/base.html @@ -5,7 +5,7 @@ - + {%- if metatags %}{{ metatags }}{% endif -%} @@ -99,6 +99,6 @@ {%- endblock -%} {%- endblock scripts -%} - + diff --git a/examples/doc/source/conf.py b/examples/doc/source/conf.py index 608aaeaeed6b..bf177aa5ae24 100644 --- a/examples/doc/source/conf.py +++ b/examples/doc/source/conf.py @@ -76,7 +76,7 @@ html_title = f"Flower Examples {release}" html_logo = "_static/flower-logo.png" html_favicon = "_static/favicon.ico" -html_baseurl = "https://flower.dev/docs/examples/" +html_baseurl = "https://flower.ai/docs/examples/" html_theme_options = { # diff --git a/examples/dp-sgd-mnist/pyproject.toml b/examples/dp-sgd-mnist/pyproject.toml index 3158ab3d52f6..161952fd2aa4 100644 --- a/examples/dp-sgd-mnist/pyproject.toml +++ b/examples/dp-sgd-mnist/pyproject.toml @@ -7,14 +7,14 @@ name = "dp-sgd-mnist" version = "0.1.0" description = "Federated training with Tensorflow Privacy" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] python = ">=3.8,<3.11" # flwr = { path = "../../", develop = true } # Development flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow-privacy = "0.8.10" diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index f326db7c678c..3f1e6c7d51b7 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -105,8 +105,8 @@ def 
apply_transforms(batch):
     trainsets = []
     validsets = []
-    for node_id in range(NUM_CLIENTS):
-        partition = fds.load_partition(node_id, "train")
+    for partition_id in range(NUM_CLIENTS):
+        partition = fds.load_partition(partition_id, "train")
         # Divide data on each node: 90% train, 10% test
         partition = partition.train_test_split(test_size=0.1)
         partition = partition.with_transform(apply_transforms)
diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py
index ae793ecd81e0..d59b31ab1569 100644
--- a/examples/embedded-devices/client_tf.py
+++ b/examples/embedded-devices/client_tf.py
@@ -40,8 +40,8 @@ def prepare_dataset(use_mnist: bool):
         fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})
         img_key = "img"
     partitions = []
-    for node_id in range(NUM_CLIENTS):
-        partition = fds.load_partition(node_id, "train")
+    for partition_id in range(NUM_CLIENTS):
+        partition = fds.load_partition(partition_id, "train")
         partition.set_format("numpy")
         # Divide data on each node: 90% train, 10% test
         partition = partition.train_test_split(test_size=0.1)
diff --git a/examples/federated-kaplan-meier-fitter/README.md b/examples/federated-kaplan-meier-fitter/README.md
new file mode 100644
index 000000000000..1569467d6f82
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/README.md
@@ -0,0 +1,104 @@
+# Flower Example using KaplanMeierFitter
+
+This is an introductory example on **federated survival analysis** using [Flower](https://flower.ai/)
+and the [lifelines](https://lifelines.readthedocs.io/en/stable/index.html) library.
+
+The aim of this example is to estimate the survival function using the
+[Kaplan-Meier Estimate](https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator) implemented in
+the lifelines library (see [KaplanMeierFitter](https://lifelines.readthedocs.io/en/stable/fitters/univariate/KaplanMeierFitter.html#lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter)). The distributed/federated aspect of this example
+is the sending of data to the server. You can think of it as a federated analytics example. However, it's worth noting that this procedure violates privacy since the raw data is exchanged.
+
+Finally, many other estimators beyond KaplanMeierFitter can be used with the provided strategy:
+AalenJohansenFitter, GeneralizedGammaFitter, LogLogisticFitter,
+SplineFitter, and WeibullFitter.
+
+We also use the [NaturalIdPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.NaturalIdPartitioner.html#flwr_datasets.partitioner.NaturalIdPartitioner) from [Flower Datasets](https://flower.ai/docs/datasets/) to divide the data according to
+the group it comes from, and therefore to simulate a division that might occur naturally (a minimal sketch follows the figure below).
+
+<div style="text-align: center;">
+  <img alt="Survival Function" src="_static/survival_function_federated.png" />
+</div>
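As the minimal sketch promised above: this is how the `NaturalIdPartitioner` splits the Waltons dataset by its `group` column, mirroring what this example's `client.py` (shown later in this patch) does:

```python
from datasets import Dataset
from flwr_datasets.partitioner import NaturalIdPartitioner
from lifelines.datasets import load_waltons

# Partition the Waltons dataset by its natural "group" column
X = load_waltons()
partitioner = NaturalIdPartitioner(partition_by="group")
partitioner.dataset = Dataset.from_pandas(X)

# Each node only ever sees the rows of its own group
partition = partitioner.load_partition(0).to_pandas()
times, events = partition["T"].values, partition["E"].values
```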
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+
+```shell
+$ git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/federated-kaplan-meier-fitter . && rm -rf _tmp && cd federated-kaplan-meier-fitter
+```
+
+This will create a new directory called `federated-kaplan-meier-fitter` containing the following files:
+
+```shell
+-- pyproject.toml
+-- requirements.txt
+-- client.py
+-- server.py
+-- centralized.py
+-- README.md
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `lifelines` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies listed in `requirements.txt`:
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run Federated Survival Analysis with Flower and lifelines' KaplanMeierFitter
+
+### Start the long-running Flower server (SuperLink)
+
+```bash
+flower-superlink --insecure
+```
+
+### Start the long-running Flower clients (SuperNodes)
+
+In a new terminal window, start the first long-running Flower client:
+
+```bash
+flower-client-app client:node_1_app --insecure
+```
+
+In yet another new terminal window, start the second long-running Flower client:
+
+```bash
+flower-client-app client:node_2_app --insecure
+```
+
+### Run the Flower App
+
+With both the long-running server (SuperLink) and two clients (SuperNodes) up and running, we can now run the actual Flower App:
+
+```bash
+flower-server-app server:app --insecure
+```
+
+You will see that the server prints the survival function and the median survival time, and saves a plot of the survival function.
+
+You can also check that the results match the centralized version.
+
+```shell
+$ python3 centralized.py
+```
diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png
new file mode 100644
index 000000000000..b7797f0879f2
Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_centralized.png differ
diff --git a/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png
new file mode 100644
index 000000000000..b7797f0879f2
Binary files /dev/null and b/examples/federated-kaplan-meier-fitter/_static/survival_function_federated.png differ
diff --git a/examples/federated-kaplan-meier-fitter/centralized.py b/examples/federated-kaplan-meier-fitter/centralized.py
new file mode 100644
index 000000000000..c94edfef9a18
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/centralized.py
@@ -0,0 +1,16 @@
+import matplotlib.pyplot as plt
+from lifelines import KaplanMeierFitter
+from lifelines.datasets import load_waltons
+
+if __name__ == "__main__":
+    X = load_waltons()
+    fitter = KaplanMeierFitter()
+    fitter.fit(X["T"], X["E"])
+    print("Survival function:")
+    print(fitter.survival_function_)
+    print("Median survival time:")
+    print(fitter.median_survival_time_)
+    fitter.plot_survival_function()
+    plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
+    plt.savefig("./_static/survival_function_centralized.png", dpi=200)
+    print("Centralized survival function saved.")
diff --git a/examples/federated-kaplan-meier-fitter/client.py b/examples/federated-kaplan-meier-fitter/client.py
new file mode 100644
index 000000000000..948492efc575
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/client.py
@@ -0,0 +1,65 @@
+from typing import Dict, List, Tuple
+
+import flwr as fl
+import numpy as np
+from datasets import Dataset
+from flwr.common import NDArray, NDArrays
+from flwr_datasets.partitioner import NaturalIdPartitioner
+from lifelines.datasets import load_waltons
+
+
+class FlowerClient(fl.client.NumPyClient):
+    """Flower client that holds and sends the events and times data.
+
+    Parameters
+    ----------
+    times: NDArray
+        Times of the `events`.
+    events: NDArray
+        Events represented by 0 - no event, 1 - event occurred.
+
+    Raises
+    ------
+    ValueError
+        If the `times` and `events` are not the same shape.
+    """
+
+    def __init__(self, times: NDArray, events: NDArray):
+        if len(times) != len(events):
+            raise ValueError("The times and events arrays have to be the same shape.")
+        self._times = times
+        self._events = events
+
+    def fit(
+        self, parameters: List[np.ndarray], config: Dict[str, str]
+    ) -> Tuple[NDArrays, int, Dict]:
+        return (
+            [self._times, self._events],
+            len(self._times),
+            {},
+        )
+
+
+# Prepare data
+X = load_waltons()
+partitioner = NaturalIdPartitioner(partition_by="group")
+partitioner.dataset = Dataset.from_pandas(X)
+
+
+def get_client_fn(partition_id: int):
+    def client_fn(cid: str):
+        partition = partitioner.load_partition(partition_id).to_pandas()
+        events = partition["E"].values
+        times = partition["T"].values
+        return FlowerClient(times=times, events=events).to_client()
+
+    return client_fn
+
+
+# Run via `flower-client-app client:node_1_app` (or `client:node_2_app`)
+node_1_app = fl.client.ClientApp(
+    client_fn=get_client_fn(0),
+)
+node_2_app = fl.client.ClientApp(
+    client_fn=get_client_fn(1),
+)
diff --git a/examples/federated-kaplan-meier-fitter/pyproject.toml b/examples/federated-kaplan-meier-fitter/pyproject.toml
new file mode 100644
index 000000000000..8fe354ffb750
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/pyproject.toml
@@ -0,0 +1,18 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "federated-kaplan-meier-fitter"
+version = "0.1.0"
+description = "Federated Kaplan Meier Fitter with Flower"
+authors = ["The Flower Authors <hello@flower.ai>"]
+maintainers = ["The Flower Authors <hello@flower.ai>"]
+
+[tool.poetry.dependencies]
+python = ">=3.9,<3.11"
+flwr-nightly = "*"
+flwr-datasets = ">=0.0.2,<1.0.0"
+numpy = ">=1.23.2"
+pandas = ">=2.0.0"
+lifelines = ">=0.28.0"
diff --git a/examples/federated-kaplan-meier-fitter/requirements.txt b/examples/federated-kaplan-meier-fitter/requirements.txt
new file mode 100644
index 000000000000..cc8146545c7b
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/requirements.txt
@@ -0,0 +1,5 @@
+flwr-nightly
+flwr-datasets>=0.0.2, <1.0.0
+numpy>=1.23.2
+pandas>=2.0.0
+lifelines>=0.28.0
diff --git a/examples/federated-kaplan-meier-fitter/server.py b/examples/federated-kaplan-meier-fitter/server.py
new file mode 100644
index 000000000000..141504ab59c0
--- /dev/null
+++ b/examples/federated-kaplan-meier-fitter/server.py
@@ -0,0 +1,145 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Strategy that supports many univariate fitters from the lifelines library."""
+
+from typing import Dict, List, Optional, Tuple, Union, Any
+
+import numpy as np
+import flwr as fl
+import matplotlib.pyplot as plt
+from flwr.common import (
+    FitIns,
+    Parameters,
+    Scalar,
+    EvaluateRes,
+    EvaluateIns,
+    FitRes,
+    parameters_to_ndarrays,
+)
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import Strategy
+from lifelines import KaplanMeierFitter
+
+
+class EventTimeFitterStrategy(Strategy):
+    """Federated strategy to aggregate the data that consist of events and times.
+
+    It works with the following uni-variate fitters from the lifelines library:
+    AalenJohansenFitter, GeneralizedGammaFitter, KaplanMeierFitter, LogLogisticFitter,
+    SplineFitter, and WeibullFitter. Note that each of them might require slightly
+    different initialization, but a constructed fitter object is required to be passed.
+
+    This strategy recreates the event and time data based on the data received from the
+    nodes.
+
+    Parameters
+    ----------
+    min_num_clients: int
+        Minimum number of clients that need to be available before they are sampled
+        for the fit round.
+    fitter: Any
+        A uni-variate fitter from the lifelines library that works with event/time
+        data, e.g. KaplanMeierFitter.
+    """
+
+    def __init__(self, min_num_clients: int, fitter: Any):
+        # The fitter can be accessed after the federated training ends
+        self._min_num_clients = min_num_clients
+        self.fitter = fitter
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the fit method."""
+        config = {}
+        fit_ins = FitIns(parameters, config)
+        clients = client_manager.sample(
+            num_clients=client_manager.num_available(),
+            min_num_clients=self._min_num_clients,
+        )
+        return [(client, fit_ins) for client in clients]
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Merge the data and fit the fitter from the lifelines library.
+
+        Assume just a single federated learning round. Assume the data from each node
+        comes as a list with two elements: 1-dim numpy arrays of times and events.
+        """
+        remote_data = [
+            (parameters_to_ndarrays(fit_res.parameters)) for _, fit_res in results
+        ]
+
+        combined_times = remote_data[0][0]
+        combined_events = remote_data[0][1]
+
+        for te_list in remote_data[1:]:
+            combined_times = np.concatenate((combined_times, te_list[0]))
+            combined_events = np.concatenate((combined_events, te_list[1]))
+
+        args_sorted = np.argsort(combined_times)
+        sorted_times = combined_times[args_sorted]
+        sorted_events = combined_events[args_sorted]
+        self.fitter.fit(sorted_times, sorted_events)
+        print("Survival function:")
+        print(self.fitter.survival_function_)
+        self.fitter.plot_survival_function()
+        plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
+        plt.savefig("./_static/survival_function_federated.png", dpi=200)
+        print("Median survival time:")
+        print(self.fitter.median_survival_time_)
+        return None, {}
+
+    # The methods below return None or empty results.
+    # They need to be implemented since the methods are abstract in the parent class
+    def initialize_parameters(
+        self, client_manager: Optional[ClientManager] = None
+    ) -> Optional[Parameters]:
+        """No parameter initialization is needed."""
+        return None
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """No centralized evaluation."""
+        return None
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """No federated evaluation."""
+        return None, {}
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """No federated evaluation."""
+        return []
+
+
+fitter = KaplanMeierFitter()  # You can choose other fitters that work on (E, T) data
+strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=fitter)
+
+app = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=1),
+    strategy=strategy,
+)
diff --git a/examples/flower-in-30-minutes/tutorial.ipynb b/examples/flower-in-30-minutes/tutorial.ipynb
index 8f9eccf65b74..0e42cff924e8 100644
--- a/examples/flower-in-30-minutes/tutorial.ipynb
+++ b/examples/flower-in-30-minutes/tutorial.ipynb
@@ -7,11 +7,11 @@
    "source": [
     "Welcome to the 30 minutes Flower federated learning tutorial!\n",
     "\n",
-    "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.dev/).\n",
+    "In this tutorial you will implement your first Federated Learning project using [Flower](https://flower.ai/).\n",
     "\n",
     "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed. A minimal understanding of ML is not required but if you already know about it, nothing is stopping you from modifying this code as you see fit!\n",
     "\n",
-    "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.dev/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n",
+    "> Star Flower on [GitHub ⭐️](https://github.com/adap/flower) and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack 🌼](https://flower.ai/join-slack/). We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.\n",
     "\n",
     "Let's get started!"
    ]
   },
@@ -59,7 +59,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We will be using the _simulation_ model in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.dev/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client."
+    "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overhead of manually managing devices. This is achieved via the `Virtual Client Engine`, the core component that runs [FL Simulations](https://flower.ai/docs/framework/how-to-run-simulations.html) with Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs or even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client."
    ]
   },
   {
@@ -69,7 +69,7 @@
    "source": [
     "## Install your ML framework\n",
     "\n",
-    "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart- example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. And check the [Flower Documentation](https://flower.dev/docs/) for even more learning materials.\n",
+    "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart_ example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory. And check the [Flower Documentation](https://flower.ai/docs/) for even more learning materials.\n",
     "\n",
     "In this tutorial we are going to use PyTorch, so let's install a recent version. We'll use a small model, so CPU-only training will suffice (this will also prevent Colab from abruptly terminating your experiment if resource limits are exceeded)"
    ]
   },
@@ -631,11 +631,11 @@
    "\n",
    "\n",
    "class FlowerClient(fl.client.NumPyClient):\n",
-   "    def __init__(self, trainloader, vallodaer) -> None:\n",
+   "    def __init__(self, trainloader, valloader) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.trainloader = trainloader\n",
-   "        self.valloader = vallodaer\n",
+   "        self.valloader = valloader\n",
    "        self.model = Net(num_classes=10)\n",
    "\n",
    "    def set_parameters(self, parameters):\n",
@@ -714,7 +714,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-   "def get_evalulate_fn(testloader):\n",
+   "def get_evaluate_fn(testloader):\n",
    "    \"\"\"This is a function that returns a function. The returned\n",
    "    function (i.e. `evaluate_fn`) will be executed by the strategy\n",
    "    at the end of each round to evaluate the state of the global\n",
@@ -747,7 +747,7 @@
    "    fraction_fit=0.1,  # let's sample 10% of the clients each round to do local training\n",
    "    fraction_evaluate=0.1,  # after each round, let's sample 10% of the clients to assess how well the global model is doing\n",
    "    min_available_clients=100,  # total number of clients available in the experiment\n",
-   "    evaluate_fn=get_evalulate_fn(testloader),\n",
+   "    evaluate_fn=get_evaluate_fn(testloader),\n",
    ")  # a callback to a function that the strategy can execute to evaluate the state of the global model on a centralised dataset"
    ]
   },
@@ -775,7 +775,7 @@
    "    \"\"\"Returns a FlowerClient containing the cid-th data partition\"\"\"\n",
    "\n",
    "    return FlowerClient(\n",
-   "        trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n",
+   "        trainloader=trainloaders[int(cid)], valloader=valloaders[int(cid)]\n",
    "    ).to_client()\n",
    "\n",
    "    return client_fn\n",
@@ -853,21 +853,21 @@
    "\n",
    "Well, if you enjoyed this content, consider giving us a ⭐️ on GitHub -> https://github.com/adap/flower\n",
    "\n",
-   "* **[DOCS]** How about running your Flower clients on the GPU? find out how to do it in the [Flower Simulation Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html)\n",
+   "* **[DOCS]** How about running your Flower clients on the GPU? Find out how to do it in the [Flower Simulation Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html)\n",
    "\n",
    "* **[VIDEO]** You can follow our [detailed line-by-line 9-videos tutorial](https://www.youtube.com/watch?v=cRebUIGB5RU&list=PLNG4feLHqCWlnj8a_E1A_n5zr2-8pafTB) about everything you need to know to design your own Flower Simulation pipelines\n",
    "\n",
    "* Check more advanced simulation examples on the Flower GitHub:\n",
    "\n",
    "    * Flower simulation with Tensorflow/Keras: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow)\n",
-   "    \n",
+   "\n",
    "    * Flower simulation with Pytorch: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)\n",
    "\n",
-   "* **[DOCS]** All Flower examples: https://flower.dev/docs/examples/\n",
+   "* **[DOCS]** All Flower examples: https://flower.ai/docs/examples/\n",
    "\n",
    "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n",
    "\n",
-   "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n"
+   "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n"
    ]
   },
   {
diff --git a/examples/flower-via-docker-compose/README.md b/examples/flower-via-docker-compose/README.md
index 1d830e46cbdb..3ef1ac37bcda 100644
--- a/examples/flower-via-docker-compose/README.md
+++ b/examples/flower-via-docker-compose/README.md
@@ -1,7 +1,7 @@
 # Leveraging Flower and Docker for Device Heterogeneity Management in Federated Learning
-  <a href="https://flower.dev/"><img src="…" alt="Flower Website" /></a>
+  <a href="https://flower.ai/"><img src="…" alt="Flower Website" /></a>
   <img src="…" alt="Docker Logo" />
@@ -141,7 +141,7 @@ By following these steps, you will have a fully functional federated learning en ### Data Pipeline with FLWR-Datasets -We have integrated [`flwr-datasets`](https://flower.dev/docs/datasets/) into our data pipeline, which is managed within the `load_data.py` file in the `helpers/` directory. This script facilitates standardized access to datasets across the federated network and incorporates a `data_sampling_percentage` argument. This argument allows users to specify the percentage of the dataset to be used for training and evaluation, accommodating devices with lower memory capabilities to prevent Out-of-Memory (OOM) errors. +We have integrated [`flwr-datasets`](https://flower.ai/docs/datasets/) into our data pipeline, which is managed within the `load_data.py` file in the `helpers/` directory. This script facilitates standardized access to datasets across the federated network and incorporates a `data_sampling_percentage` argument. This argument allows users to specify the percentage of the dataset to be used for training and evaluation, accommodating devices with lower memory capabilities to prevent Out-of-Memory (OOM) errors. ### Model Selection and Dataset diff --git a/examples/ios/pyproject.toml b/examples/ios/pyproject.toml index c1bdbb815bd5..2e55b14cf761 100644 --- a/examples/ios/pyproject.toml +++ b/examples/ios/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_ios_coreml" version = "0.1.0" description = "Example Server for Flower iOS/CoreML" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/mt-pytorch/README.md b/examples/mt-pytorch/README.md deleted file mode 100644 index 721a26ed814d..000000000000 --- a/examples/mt-pytorch/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# Flower App (PyTorch) 🧪 - -🧪 = This example covers experimental features that might change in future versions of Flower - -Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. - -This how-to guide describes the deployment of a long-running Flower server. - -## Preconditions - -Let's assume the following project structure: - -```bash -$ tree . -. 
-├── client.py -├── server.py -├── task.py -└── requirements.txt -``` - -## Install dependencies - -```bash -pip install -r requirements.txt -``` - -## Start the SuperLink - -```bash -flower-superlink --insecure -``` - -## Start the long-running Flower client - -In a new terminal window, start the first long-running Flower client: - -```bash -flower-client client:app --insecure -``` - -In yet another new terminal window, start the second long-running Flower client: - -```bash -flower-client client:app --insecure -``` - -## Start the driver - -```bash -python driver.py -``` diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py deleted file mode 100644 index 06091c954cef..000000000000 --- a/examples/mt-pytorch/driver.py +++ /dev/null @@ -1,226 +0,0 @@ -from typing import List, Tuple -import random -import time - -from flwr.server.driver import GrpcDriver -from flwr.common import ( - ServerMessage, - FitIns, - ndarrays_to_parameters, - serde, - parameters_to_ndarrays, - ClientMessage, - NDArrays, - Code, -) -from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 -from flwr.server.strategy.aggregate import aggregate -from flwr.common import Metrics -from flwr.server import History -from flwr.common import serde -from task import Net, get_parameters, set_parameters - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) -# -------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 2 -sleep_time = 1 -num_rounds = 3 -parameters = ndarrays_to_parameters(get_parameters(net=Net())) - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_run_res: driver_pb2.CreateRunResponse = driver.create_run( - req=driver_pb2.CreateRunRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -run_id = create_run_res.run_id -print(f"Created run id {run_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_nodes: List[node_pb2.Node] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. 
- # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids: List[int] = [0] * num_client_nodes_per_round - sampled_nodes = [ - node_pb2.Node(node_id=node_id, anonymous=False) - for node_id in sampled_node_ids - ] - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. - while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_nodes: List[node_pb2.Node] = get_nodes_res.nodes - print(f"Got {len(all_nodes)} client nodes") - - if len(all_nodes) >= num_client_nodes_per_round: - # Sample client nodes - sampled_nodes = random.sample(all_nodes, num_client_nodes_per_round) - break - - time.sleep(3) - - # Log sampled node IDs - print(f"Sampled {len(sampled_nodes)} node IDs: {sampled_nodes}") - time.sleep(sleep_time) - - # Schedule a task for all sampled nodes - fit_ins: FitIns = FitIns(parameters=parameters, config={}) - server_message_proto: transport_pb2.ServerMessage = serde.server_message_to_proto( - server_message=ServerMessage(fit_ins=fit_ins) - ) - task_ins_list: List[task_pb2.TaskIns] = [] - for sampled_node in sampled_nodes: - new_task_ins = task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - run_id=run_id, - task=task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=sampled_node, - legacy_server_message=server_message_proto, - ), - ) - task_ins_list.append(new_task_ins) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - print( - f"Scheduled {len(push_task_ins_res.task_ids)} tasks: {push_task_ins_res.task_ids}" - ) - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - print(f"Got {len(task_res_list)} results") - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages: List[ClientMessage] 
= [] - for task_res in all_task_res: - if task_res.task.HasField("legacy_client_message"): - node_messages.append(task_res.task.legacy_client_message) - print(f"Received {len(node_messages)} results") - - weights_results: List[Tuple[NDArrays, int]] = [] - metrics_results: List = [] - for node_message in node_messages: - if not node_message.fit_res: - continue - fit_res = node_message.fit_res - # Aggregate only if the status is OK - if fit_res.status.code != Code.OK.value: - continue - weights_results.append( - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - ) - metrics_results.append( - (fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics)) - ) - - # Aggregate parameters (FedAvg) - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated - - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) - - # Slow down the start of the next round - time.sleep(sleep_time) - -print("app_fit: losses_distributed %s", str(history.losses_distributed)) -print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit)) -print("app_fit: metrics_distributed %s", str(history.metrics_distributed)) -print("app_fit: losses_centralized %s", str(history.losses_centralized)) -print("app_fit: metrics_centralized %s", str(history.metrics_centralized)) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/mxnet-from-centralized-to-federated/client.py b/examples/mxnet-from-centralized-to-federated/client.py index 3a3a9de146ef..bb666a26508e 100644 --- a/examples/mxnet-from-centralized-to-federated/client.py +++ b/examples/mxnet-from-centralized-to-federated/client.py @@ -1,6 +1,5 @@ """Flower client example using MXNet for MNIST classification.""" - from typing import Dict, List, Tuple import flwr as fl diff --git a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py index a53ed018eb48..5cf39da7c9ca 100644 --- a/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py +++ b/examples/mxnet-from-centralized-to-federated/mxnet_mnist.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - from typing import List, Tuple import mxnet as mx from mxnet import gluon diff --git a/examples/mxnet-from-centralized-to-federated/pyproject.toml b/examples/mxnet-from-centralized-to-federated/pyproject.toml index 952683eb90f6..b00b3ddfe412 100644 --- a/examples/mxnet-from-centralized-to-federated/pyproject.toml +++ b/examples/mxnet-from-centralized-to-federated/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/mxnet-from-centralized-to-federated/server.py b/examples/mxnet-from-centralized-to-federated/server.py index 29cbce1884d1..871aa4e8ec99 100644 --- a/examples/mxnet-from-centralized-to-federated/server.py +++ b/examples/mxnet-from-centralized-to-federated/server.py @@ -1,6 +1,5 
@@ """Flower server example.""" - import flwr as fl if __name__ == "__main__": diff --git a/examples/opacus/pyproject.toml b/examples/opacus/pyproject.toml index af0eaf596fbf..26914fa27aa4 100644 --- a/examples/opacus/pyproject.toml +++ b/examples/opacus/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "flwr_opacus" version = "0.1.0" description = "Differentially Private Federated Learning with Opacus and Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-federated-variational-autoencoder/pyproject.toml b/examples/pytorch-federated-variational-autoencoder/pyproject.toml index 116140306f62..bc1f85803682 100644 --- a/examples/pytorch-federated-variational-autoencoder/pyproject.toml +++ b/examples/pytorch-federated-variational-autoencoder/pyproject.toml @@ -2,7 +2,7 @@ name = "pytorch_federated_variational_autoencoder" version = "0.1.0" description = "Federated Variational Autoencoder Example" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index fccb14158ecd..06ee89dddcac 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -2,7 +2,7 @@ This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. -This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. 
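As a minimal sketch of that dataset handling with `flwr-datasets` (it mirrors the renamed `load_data(partition_id)` shown in the `cifar.py` diff further below; the normalization constants here are illustrative, not taken from this patch):

```python
from flwr_datasets import FederatedDataset
from torchvision.transforms import Compose, Normalize, ToTensor

# Split CIFAR-10 into 10 partitions; each client loads one by its partition ID
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(partition_id=0)

# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)

# Illustrative CIFAR-10 transforms
pytorch_transforms = Compose(
    [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
```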
## Project Setup diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index e8f3ec3fd724..277a21da2e70 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -51,10 +51,10 @@ def forward(self, x: Tensor) -> Tensor: return x -def load_data(node_id: int): +def load_data(partition_id: int): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index f89e03bc2053..9df4739e0aab 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -1,4 +1,5 @@ """Flower client example using PyTorch for CIFAR-10 image classification.""" + import argparse from collections import OrderedDict from typing import Dict, List, Tuple @@ -80,11 +81,11 @@ def evaluate( def main() -> None: """Load data, start CifarClient.""" parser = argparse.ArgumentParser(description="Flower") - parser.add_argument("--node-id", type=int, required=True, choices=range(0, 10)) + parser.add_argument("--partition-id", type=int, required=True, choices=range(0, 10)) args = parser.parse_args() # Load data - trainloader, testloader = cifar.load_data(args.node_id) + trainloader, testloader = cifar.load_data(args.partition_id) # Load model model = cifar.Net().to(DEVICE).train() diff --git a/examples/pytorch-from-centralized-to-federated/pyproject.toml b/examples/pytorch-from-centralized-to-federated/pyproject.toml index 6d6f138a0aea..3d1559e3a515 100644 --- a/examples/pytorch-from-centralized-to-federated/pyproject.toml +++ b/examples/pytorch-from-centralized-to-federated/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch: From Centralized To Federated with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/pytorch-from-centralized-to-federated/run.sh b/examples/pytorch-from-centralized-to-federated/run.sh index 1ed51dd787ac..6ddf6ad476b4 100755 --- a/examples/pytorch-from-centralized-to-federated/run.sh +++ b/examples/pytorch-from-centralized-to-federated/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/pytorch-from-centralized-to-federated/server.py b/examples/pytorch-from-centralized-to-federated/server.py index 42f34b3a78e9..5190d690dc20 100644 --- a/examples/pytorch-from-centralized-to-federated/server.py +++ b/examples/pytorch-from-centralized-to-federated/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - from typing import List, Tuple import flwr as fl diff --git a/examples/quickstart-cpp/driver.py b/examples/quickstart-cpp/driver.py index 3b3036f7e928..f19cf0e9bd98 100644 --- a/examples/quickstart-cpp/driver.py +++ b/examples/quickstart-cpp/driver.py @@ -3,7 +3,7 @@ 
# Start Flower server for three rounds of federated learning
 if __name__ == "__main__":
-    fl.server.driver.start_driver(
+    fl.server.start_driver(
         server_address="0.0.0.0:9091",
         config=fl.server.ServerConfig(num_rounds=3),
         strategy=FedAvgCpp(),
diff --git a/examples/quickstart-fastai/pyproject.toml b/examples/quickstart-fastai/pyproject.toml
index ffaa97267493..19a25291a6af 100644
--- a/examples/quickstart-fastai/pyproject.toml
+++ b/examples/quickstart-fastai/pyproject.toml
@@ -6,9 +6,11 @@ build-backend = "poetry.core.masonry.api"
 name = "quickstart-fastai"
 version = "0.1.0"
 description = "Fastai Federated Learning Quickstart with Flower"
-authors = ["The Flower Authors <hello@flower.dev>"]
+authors = ["The Flower Authors <hello@flower.ai>"]

 [tool.poetry.dependencies]
-python = ">=3.8,<3.10"
-flwr = "^1.0.0"
-fastai = "^2.7.10"
+python = ">=3.8,<3.11"
+flwr = ">=1.0,<2.0"
+fastai = "2.7.14"
+torch = "2.2.0"
+torchvision = "0.17.0"
diff --git a/examples/quickstart-fastai/requirements.txt b/examples/quickstart-fastai/requirements.txt
index 0a7f315018a2..9c6e8d77293a 100644
--- a/examples/quickstart-fastai/requirements.txt
+++ b/examples/quickstart-fastai/requirements.txt
@@ -1,3 +1,4 @@
-fastai~=2.7.12
-flwr~=1.4.0
-torch~=2.0.1
+flwr>=1.0, <2.0
+fastai==2.7.14
+torch==2.2.0
+torchvision==0.17.0
diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md
index fd868aa1fcce..ce7790cd4af5 100644
--- a/examples/quickstart-huggingface/README.md
+++ b/examples/quickstart-huggingface/README.md
@@ -1,6 +1,6 @@
 # Federated HuggingFace Transformers using Flower and PyTorch

-This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline.
+This is an introductory example of using [HuggingFace](https://huggingface.co) Transformers with Flower and PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.ai/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline.

 Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy.

@@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear
 Start client 1 in the first terminal:

 ```shell
-python3 client.py --node-id 0
+python3 client.py --partition-id 0
 ```

 Start client 2 in the second terminal:

 ```shell
-python3 client.py --node-id 1
+python3 client.py --partition-id 1
 ```

 You will see that PyTorch is starting a federated training.
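Each `--partition-id` above selects a disjoint shard of IMDB; a minimal sketch of the partition loading the two clients end up doing (mirroring this example's `load_data`, shown in the `client.py` diff below):

```python
from flwr_datasets import FederatedDataset

# IMDB split into 1,000 IID partitions; clients 0 and 1 load disjoint shards
fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000})
partition = fds.load_partition(0)  # the shard behind `--partition-id 0`

# Divide data locally: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
```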
diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 5dc461d30536..9be08d0cbcf4 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -17,10 +17,10 @@ CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(node_id): +def load_data(partition_id): """Load IMDB data (training and eval)""" fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) @@ -78,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(node_id): +def main(partition_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data(node_id) + trainloader, testloader = load_data(partition_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -116,12 +116,12 @@ def evaluate(self, parameters, config): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=list(range(1_000)), required=True, type=int, help="Partition of the dataset divided into 1,000 iid partitions created " "artificially.", ) - node_id = parser.parse_args().node_id - main(node_id) + partition_id = parser.parse_args().partition_id + main(partition_id) diff --git a/examples/quickstart-huggingface/pyproject.toml b/examples/quickstart-huggingface/pyproject.toml index 50ba0b37f8d2..2b46804d7b45 100644 --- a/examples/quickstart-huggingface/pyproject.toml +++ b/examples/quickstart-huggingface/pyproject.toml @@ -7,8 +7,8 @@ name = "quickstart-huggingface" version = "0.1.0" description = "Hugging Face Transformers Federated Learning Quickstart with Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index e722a24a21a9..fa989eab1471 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i}& + python client.py --partition-id ${i}& done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py index afd6f197bcde..2257a3d6daa3 100644 --- a/examples/quickstart-jax/client.py +++ b/examples/quickstart-jax/client.py @@ -1,6 +1,5 @@ """Flower client example using JAX for linear regression.""" - from typing import Dict, List, Tuple, Callable import flwr as fl diff --git a/examples/quickstart-jax/jax_training.py b/examples/quickstart-jax/jax_training.py index 2b523a08516e..a2e23a0927bc 100644 --- a/examples/quickstart-jax/jax_training.py +++ b/examples/quickstart-jax/jax_training.py @@ -7,7 +7,6 @@ please read the JAX documentation or the mentioned tutorial. 
""" - from typing import Dict, List, Tuple, Callable import jax import jax.numpy as jnp diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index 41b4462d0a14..c956191369b5 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -2,7 +2,7 @@ name = "jax_example" version = "0.1.0" description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-mlcube/pyproject.toml b/examples/quickstart-mlcube/pyproject.toml index 0d42fc3b2898..a2862bd5ebb7 100644 --- a/examples/quickstart-mlcube/pyproject.toml +++ b/examples/quickstart-mlcube/pyproject.toml @@ -6,13 +6,13 @@ build-backend = "poetry.masonry.api" name = "quickstart-ml-cube" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" # For development: { path = "../../", develop = true } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } mlcube = "0.0.9" mlcube-docker = "0.0.9" tensorflow-estimator = ">=2.9.1,<2.11.1 || >2.11.1" diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index d94a87a014f7..cca55bcb946a 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -66,19 +66,19 @@ following commands. Start a first client in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` And another one in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` If you want to utilize your GPU, you can use the `--gpu` argument: ```shell -python3 client.py --gpu --node-id 2 +python3 client.py --gpu --partition-id 2 ``` Note that you can start many more clients if you want, but each will have to be in its own terminal. 
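Under the hood, the `--gpu` flag only needs to switch MLX's default device. A minimal sketch of how such a flag can be wired up (an assumption about this example's internals, using MLX's public `set_default_device` API rather than code from this patch):

```python
import argparse

import mlx.core as mx

parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()

# Route all subsequent MLX computation to the Metal back-end if requested
if args.gpu:
    mx.set_default_device(mx.gpu)
```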
@@ -96,7 +96,7 @@ We will use `flwr_datasets` to easily download and partition the `MNIST` dataset ```python fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) -partition = fds.load_partition(node_id = args.node_id) +partition = fds.load_partition(partition_id = args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits['train'].set_format("numpy") diff --git a/examples/quickstart-mlx/client.py b/examples/quickstart-mlx/client.py index 3b506399a5f1..faba2b94d6bd 100644 --- a/examples/quickstart-mlx/client.py +++ b/examples/quickstart-mlx/client.py @@ -89,7 +89,7 @@ def evaluate(self, parameters, config): parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", @@ -106,7 +106,7 @@ def evaluate(self, parameters, config): learning_rate = 1e-1 fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) - partition = fds.load_partition(node_id=args.node_id) + partition = fds.load_partition(partition_id=args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits["train"].set_format("numpy") diff --git a/examples/quickstart-mlx/pyproject.toml b/examples/quickstart-mlx/pyproject.toml index deb541c5ba9c..752040b6aaa9 100644 --- a/examples/quickstart-mlx/pyproject.toml +++ b/examples/quickstart-mlx/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-mlx" version = "0.1.0" description = "MLX Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" mlx = "==0.0.3" numpy = "==1.24.4" -flwr-datasets = {extras = ["vision"], version = "^0.0.2"} +flwr-datasets = { extras = ["vision"], version = "^0.0.2" } diff --git a/examples/quickstart-mlx/run.sh b/examples/quickstart-mlx/run.sh index 70281049517d..40d211848c07 100755 --- a/examples/quickstart-mlx/run.sh +++ b/examples/quickstart-mlx/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-mxnet/client.py b/examples/quickstart-mxnet/client.py index b0c937f7350f..6c2b2e99775d 100644 --- a/examples/quickstart-mxnet/client.py +++ b/examples/quickstart-mxnet/client.py @@ -5,7 +5,6 @@ https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html """ - import flwr as fl import numpy as np import mxnet as mx diff --git a/examples/quickstart-mxnet/pyproject.toml b/examples/quickstart-mxnet/pyproject.toml index 952683eb90f6..b00b3ddfe412 100644 --- a/examples/quickstart-mxnet/pyproject.toml +++ b/examples/quickstart-mxnet/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "mxnet_example" version = "0.1.0" description = "MXNet example with MNIST and CNN" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-mxnet/server.py b/examples/quickstart-mxnet/server.py index 29cbce1884d1..871aa4e8ec99 100644 --- a/examples/quickstart-mxnet/server.py +++ 
b/examples/quickstart-mxnet/server.py @@ -1,6 +1,5 @@ """Flower server example.""" - import flwr as fl if __name__ == "__main__": diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index a25e6ea6ee36..dd69f3ead3cb 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -1,6 +1,6 @@ # Flower Example using Pandas -This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset. Running this example in itself is quite easy. @@ -70,13 +70,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -$ python3 client.py --node-id 0 +$ python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -$ python3 client.py --node-id 1 +$ python3 client.py --partition-id 1 ``` -You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-pandas.html) for a detailed explanation. +You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look at the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-pandas.html) for a detailed explanation.
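The rename that runs through all of these examples is mechanical: the CLI flag `--node-id` becomes `--partition-id`, and the value is passed straight to `FederatedDataset.load_partition`. A minimal end-to-end sketch of the renamed flow, mirroring the CIFAR-10 quickstart client hunks further below (dataset names and partition counts vary per example):

```python
import argparse

from flwr_datasets import FederatedDataset

# Parse this client's partition id (the flag was previously --node-id)
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument("--partition-id", choices=[0, 1, 2], required=True, type=int)
partition_id = parser.parse_args().partition_id

# Split CIFAR-10 into 3 IID partitions and load the one for this client
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(partition_id)

# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
```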
diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index 8585922e4572..c52b7c65b04c 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -42,14 +42,14 @@ def fit( parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, - help="Specifies the node id of artificially partitioned datasets.", + help="Specifies the partition id of artificially partitioned datasets.", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) diff --git a/examples/quickstart-pandas/pyproject.toml b/examples/quickstart-pandas/pyproject.toml index 6229210d6488..2e6b1424bb54 100644 --- a/examples/quickstart-pandas/pyproject.toml +++ b/examples/quickstart-pandas/pyproject.toml @@ -7,7 +7,7 @@ name = "quickstart-pandas" version = "0.1.0" description = "Pandas Federated Analytics Quickstart with Flower" authors = ["Ragy Haddad "] -maintainers = ["The Flower Authors "] +maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-pandas/run.sh b/examples/quickstart-pandas/run.sh index 571fa8bfb3e4..2ae1e582b8cf 100755 --- a/examples/quickstart-pandas/run.sh +++ b/examples/quickstart-pandas/run.sh @@ -4,7 +4,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i} & + python client.py --partition-id ${i} & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index 1287b50bca65..fb29c7e9e9ea 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1,6 +1,6 @@ # Flower Example using PyTorch Lightning -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -57,20 +57,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the following commands. 
Start client 1 in the first terminal: ```shell -python client.py --node-id 0 +python client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python client.py --node-id 1 +python client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index fc5f1ee03cfe..6e21259cc492 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -58,18 +58,18 @@ def _set_parameters(model, parameters): def main() -> None: parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, 10), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - node_id = args.node_id + partition_id = args.partition_id # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data(node_id) + train_loader, val_loader, test_loader = mnist.load_data(partition_id) # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() diff --git a/examples/quickstart-pytorch-lightning/pyproject.toml b/examples/quickstart-pytorch-lightning/pyproject.toml index 853ef9c1646f..a09aaa3d65b5 100644 --- a/examples/quickstart-pytorch-lightning/pyproject.toml +++ b/examples/quickstart-pytorch-lightning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.masonry.api" name = "quickstart-pytorch-lightning" version = "0.1.0" description = "Federated Learning Quickstart with Flower and PyTorch Lightning" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" diff --git a/examples/quickstart-pytorch-lightning/run.sh b/examples/quickstart-pytorch-lightning/run.sh index 60893a9a055b..62a1dac199bd 100755 --- a/examples/quickstart-pytorch-lightning/run.sh +++ b/examples/quickstart-pytorch-lightning/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 6de0dcf7ab32..02c9b4b38498 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Example using PyTorch -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. 
## Project Setup @@ -55,20 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index b5ea4c94dd21..e640ce111dff 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -69,10 +69,10 @@ def test(net, testloader): return loss, accuracy -def load_data(node_id): +def load_data(partition_id): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( @@ -94,20 +94,20 @@ def apply_transforms(batch): # 2. Federation of the pipeline with Flower # ############################################################################# -# Get node id +# Get partition id parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], required=True, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) -node_id = parser.parse_args().node_id +partition_id = parser.parse_args().partition_id # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data(node_id=node_id) +trainloader, testloader = load_data(partition_id=partition_id) # Define Flower client diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index ec6a3af8c5b4..d8e1503dd8a7 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-pytorch" version = "0.1.0" description = "PyTorch Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/quickstart-pytorch/run.sh b/examples/quickstart-pytorch/run.sh index cdace99bb8df..6ca9c8cafec9 100755 --- a/examples/quickstart-pytorch/run.sh +++ b/examples/quickstart-pytorch/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "$i" & + python client.py --partition-id "$i" & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md index d62525c96c18..a975a9392800 100644 --- 
a/examples/quickstart-sklearn-tabular/README.md +++ b/examples/quickstart-sklearn-tabular/README.md @@ -3,7 +3,7 @@ This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on "iris" (tabular) dataset. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset. ## Project Setup @@ -63,15 +63,15 @@ poetry run python3 server.py Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: ```shell -poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2} +poetry run python3 client.py --partition-id 0 # partition-id should be any of {0,1,2} ``` Alternatively you can run all of it in one shell as follows: ```shell poetry run python3 server.py & -poetry run python3 client.py --node-id 0 & -poetry run python3 client.py --node-id 1 +poetry run python3 client.py --partition-id 0 & +poetry run python3 client.py --partition-id 1 ``` You will see that Flower is starting a federated training. diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py index 5dc0e88b3c75..fcab8f5d5612 100644 --- a/examples/quickstart-sklearn-tabular/client.py +++ b/examples/quickstart-sklearn-tabular/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) diff --git a/examples/quickstart-sklearn-tabular/pyproject.toml b/examples/quickstart-sklearn-tabular/pyproject.toml index 34a78048d3b0..86eab5c38df0 100644 --- a/examples/quickstart-sklearn-tabular/pyproject.toml +++ b/examples/quickstart-sklearn-tabular/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/quickstart-sklearn-tabular/run.sh +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-tabnet/pyproject.toml b/examples/quickstart-tabnet/pyproject.toml index 948eaf085b86..18f1979791bd 100644 --- a/examples/quickstart-tabnet/pyproject.toml +++ b/examples/quickstart-tabnet/pyproject.toml @@ -6,12 +6,12 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tabnet" version = "0.1.0" description = "Tabnet Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] 
+authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } tensorflow_datasets = "4.8.3" tabnet = "0.1.6" diff --git a/examples/quickstart-tensorflow/README.md b/examples/quickstart-tensorflow/README.md index 92d38c9340d7..8d5e9434b086 100644 --- a/examples/quickstart-tensorflow/README.md +++ b/examples/quickstart-tensorflow/README.md @@ -1,7 +1,7 @@ # Flower Example using TensorFlow/Keras This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index 37abbbcc46ec..3e2035c09311 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -11,7 +11,7 @@ # Parse arguments parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=[0, 1, 2], required=True, @@ -26,7 +26,7 @@ # Download and partition dataset fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) -partition = fds.load_partition(args.node_id, "train") +partition = fds.load_partition(args.partition_id, "train") partition.set_format("numpy") # Divide data on each node: 80% train, 20% test diff --git a/examples/quickstart-tensorflow/pyproject.toml b/examples/quickstart-tensorflow/pyproject.toml index e027a7353181..98aeb932cab9 100644 --- a/examples/quickstart-tensorflow/pyproject.toml +++ b/examples/quickstart-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "quickstart-tensorflow" version = "0.1.0" description = "Keras Federated Learning Quickstart with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} -tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} +tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } +tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/quickstart-tensorflow/run.sh b/examples/quickstart-tensorflow/run.sh index 439abea8df4b..76188f197e3e 100755 --- a/examples/quickstart-tensorflow/run.sh +++ b/examples/quickstart-tensorflow/run.sh 
@@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/secaggplus-mt/pyproject.toml b/examples/secaggplus-mt/pyproject.toml index 94d8defa3316..fe6fc67252b8 100644 --- a/examples/secaggplus-mt/pyproject.toml +++ b/examples/secaggplus-mt/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "secaggplus-mt" version = "0.1.0" description = "Secure Aggregation with Driver API" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py index b98de883b8f7..4079a2568f2f 100644 --- a/examples/secaggplus-mt/workflows.py +++ b/examples/secaggplus-mt/workflows.py @@ -54,13 +54,13 @@ RECORD_KEY_CONFIGS, ) from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage +from flwr.common.typing import ConfigsRecordValues, FitIns from flwr.proto.task_pb2 import Task from flwr.common import serde -from flwr.common.constant import TASK_TYPE_FIT -from flwr.common.recordset import RecordSet +from flwr.common.constant import MESSAGE_TYPE_FIT +from flwr.common import RecordSet from flwr.common import recordset_compat as compat -from flwr.common.configsrecord import ConfigsRecord +from flwr.common import ConfigsRecord LOG_EXPLAIN = True @@ -79,16 +79,16 @@ def _wrap_in_task( recordset = compat.fitins_to_recordset(fit_ins, keep_input=True) else: recordset = RecordSet() - recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values)) + recordset.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(named_values) return Task( - task_type=TASK_TYPE_FIT, + task_type=MESSAGE_TYPE_FIT, recordset=serde.recordset_to_proto(recordset), ) def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]: recordset = serde.recordset_from_proto(task.recordset) - return recordset.get_configs(RECORD_KEY_CONFIGS).data + return recordset.configs_records[RECORD_KEY_CONFIGS] _secure_aggregation_configuration = { diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md index 11b7a3364376..5ba5ec70dc3e 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using PyTorch -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This example uses 100 clients by default.
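The `workflows.py` hunks above carry the one substantive API migration in this section: `RecordSet` and `ConfigsRecord` are now imported from `flwr.common` directly, and named records are read and written through the `configs_records` mapping instead of `set_configs`/`get_configs`. A before/after sketch of that change (the record key here is illustrative; the example imports its own `RECORD_KEY_CONFIGS` constant):

```python
from flwr.common import ConfigsRecord, RecordSet

RECORD_KEY = "secaggplus"  # illustrative; the example uses RECORD_KEY_CONFIGS

recordset = RecordSet()

# Write: previously recordset.set_configs(RECORD_KEY, ConfigsRecord({"stage": 0}))
recordset.configs_records[RECORD_KEY] = ConfigsRecord({"stage": 0})

# Read: previously recordset.get_configs(RECORD_KEY).data; the record itself
# now behaves as the mapping of config values
named_values = recordset.configs_records[RECORD_KEY]
```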
## Running the example (via Jupyter Notebook) @@ -79,4 +79,4 @@ python sim.py --num_cpus=2 python sim.py --num_cpus=2 --num_gpus=0.2 ``` -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. diff --git a/examples/simulation-pytorch/pyproject.toml b/examples/simulation-pytorch/pyproject.toml index 07918c0cd17c..5978c17f2c60 100644 --- a/examples/simulation-pytorch/pyproject.toml +++ b/examples/simulation-pytorch/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "simulation-pytorch" version = "0.1.0" description = "Federated Learning Simulation with Flower and PyTorch" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -17,4 +17,3 @@ torchvision = "0.16.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.27.0" - diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index d1e7358566cc..93a79d2f0e0a 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -30,7 +30,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.dev/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." + "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.ai/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs or even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -178,7 +178,7 @@ "\n", "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST."
] }, { @@ -605,11 +605,11 @@ "\n", "Get all resources you need!\n", "\n", - "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n", - "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** Our complete documentation: https://flower.ai/docs/\n", + "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n" ] } ], diff --git a/examples/simulation-tensorflow/README.md b/examples/simulation-tensorflow/README.md index f0d94f343d37..75be823db2eb 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using TensorFlow/Keras -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This example uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -78,4 +78,4 @@ python sim.py --num_cpus=2 python sim.py --num_cpus=2 --num_gpus=0.2 ``` -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.
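Both simulation READMEs feed their `--num_cpus`/`--num_gpus` flags into the same entry point described in the notebooks: `start_simulation` plus per-client resource caps. A rough sketch of that wiring, where `FlowerClient` stands in for each example's `NumPyClient` subclass:

```python
import flwr as fl

NUM_CLIENTS = 100  # the default used by both simulation examples


def client_fn(cid: str) -> fl.client.Client:
    # Build the virtual client for partition `cid`; FlowerClient is assumed
    # to be the example's NumPyClient subclass
    return FlowerClient(partition_id=int(cid)).to_client()


history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    # e.g. --num_cpus=2 --num_gpus=0.2 packs five virtual clients per GPU
    client_resources={"num_cpus": 2, "num_gpus": 0.2},
)
```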
diff --git a/examples/simulation-tensorflow/pyproject.toml b/examples/simulation-tensorflow/pyproject.toml index f2e7bd3006c0..ad8cc2032b2d 100644 --- a/examples/simulation-tensorflow/pyproject.toml +++ b/examples/simulation-tensorflow/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "poetry.core.masonry.api" name = "simulation-tensorflow" version = "0.1.0" description = "Federated Learning Simulation with Flower and Tensorflow" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } -tensorflow = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"} -tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"} +tensorflow = { version = "^2.9.1, !=2.11.1", markers = "platform_machine == 'x86_64'" } +tensorflow-macos = { version = "^2.9.1, !=2.11.1", markers = "sys_platform == 'darwin' and platform_machine == 'arm64'" } diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 5ef1992bcc7e..9acfba99237c 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -232,7 +232,7 @@ "\n", "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.\n", "\n", - "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." + "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." ] }, { @@ -323,11 +323,11 @@ "\n", "Get all resources you need!\n", "\n", - "* **[DOCS]** Our complete documenation: https://flower.dev/docs/\n", - "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[DOCS]** Our complete documentation: https://flower.ai/docs/\n", + "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n", "* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs\n", "\n", - "Don't forget to join our Slack channel: https://flower.dev/join-slack/" + "Don't forget to join our Slack channel: https://flower.ai/join-slack/" ] } ], diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index ee3cdfc9768e..12b1a5e3bc1a 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,7 +1,7 @@ # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. +Running this example in itself is quite easy.
This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 # or any integer in {0-9} +python3 client.py --partition-id 0 # or any integer in {0-9} ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 # or any integer in {0-9} +python3 client.py --partition-id 1 # or any integer in {0-9} ``` Alternatively, you can run all of it in one shell as follows: diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index 3d41cb6fbb21..1e9349df1acc 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS}) diff --git a/examples/sklearn-logreg-mnist/pyproject.toml b/examples/sklearn-logreg-mnist/pyproject.toml index 8ea49fe187a2..58cc5ca4a02e 100644 --- a/examples/sklearn-logreg-mnist/pyproject.toml +++ b/examples/sklearn-logreg-mnist/pyproject.toml @@ -7,8 +7,8 @@ name = "sklearn-mnist" version = "0.1.0" description = "Federated learning with scikit-learn and Flower" authors = [ - "The Flower Authors ", - "Kaushik Amar Das " + "The Flower Authors ", + "Kaushik Amar Das ", ] [tool.poetry.dependencies] diff --git a/examples/sklearn-logreg-mnist/run.sh b/examples/sklearn-logreg-mnist/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/sklearn-logreg-mnist/run.sh +++ b/examples/sklearn-logreg-mnist/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/vertical-fl/pyproject.toml b/examples/vertical-fl/pyproject.toml index 14771c70062f..19dcd0e7a842 100644 --- a/examples/vertical-fl/pyproject.toml +++ b/examples/vertical-fl/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "vertical-fl" version = "0.1.0" description = "PyTorch Vertical FL with Flower" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index e89a09519fed..ddebe51247b2 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -110,7 +110,7 @@ An overview of the FL pipeline built with Flower for this example is illustrated 3. Once on-site training is completed, each client sends back the (now updated) classification head to the Flower server. 4. The Flower server aggregates (via FedAvg) the classification heads in order to obtain a new _global_ classification head. This head will be shared with clients in the next round. -Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. 
The former, managed by the [`VirtualClientEngine`](https://flower.dev/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. +Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. The former, managed by the [`VirtualClientEngine`](https://flower.ai/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. ### Preparing the dataset @@ -147,7 +147,7 @@ INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {' With just 5 FL rounds, the global model should be reaching ~95% validation accuracy. A test accuracy of 97% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~20-30s depending on the amount of data the clients selected in a round have. -Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. +Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. ### Federated Finetuning (non-simulated) diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml index dd5578b8b3d0..27a89578c5a0 100644 --- a/examples/whisper-federated-finetuning/pyproject.toml +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "whisper-flower" version = "0.1.0" description = "On-device Federated Downstreaming for Speech Classification" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -16,4 +16,4 @@ tokenizers = "0.13.3" datasets = "2.14.6" soundfile = "0.12.1" librosa = "0.10.1" -# this example was tested with pytorch 2.1.0 \ No newline at end of file +# this example was tested with pytorch 2.1.0 diff --git a/examples/whisper-federated-finetuning/requirements.txt b/examples/whisper-federated-finetuning/requirements.txt index eb4a5d7eb47b..f16b3d6993ce 100644 --- a/examples/whisper-federated-finetuning/requirements.txt +++ b/examples/whisper-federated-finetuning/requirements.txt @@ -3,5 +3,4 @@ tokenizers==0.13.3 datasets==2.14.6 soundfile==0.12.1 librosa==0.10.1 -flwr==1.5.0 -ray==2.6.3 \ No newline at end of file +flwr[simulation]>=1.0, <2.0 \ No newline at end of file diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index 97ecc39b47f2..dc6d7e3872d6 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -1,7 +1,7 @@ # Flower Example using XGBoost (Comprehensive) This example demonstrates a comprehensive federated learning setup using Flower with XGBoost. 
-We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. This examples uses [Flower Datasets](https://flower.dev/docs/datasets/) to retrieve, partition and preprocess the data for each Flower client. +We use the [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to retrieve, partition and preprocess the data for each Flower client. It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) example in the following ways: - Arguments parsers of server and clients for hyperparameters selection. @@ -91,7 +91,7 @@ pip install -r requirements.txt ## Run Federated Learning with XGBoost and Flower -You can run this example in two ways: either by manually launching the server, and then several clients that connect to it; or by launching a Flower simulation. Both run the same workload, yielding identical results. The former is ideal for deployments on different machines, while the latter makes it easy to simulate large client cohorts in a resource-aware manner. You can read more about how Flower Simulation works in the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html). The commands shown below assume you have activated your environment (if you decide to use Poetry, you can activate it via `poetry shell`). +You can run this example in two ways: either by manually launching the server, and then several clients that connect to it; or by launching a Flower simulation. Both run the same workload, yielding identical results. The former is ideal for deployments on different machines, while the latter makes it easy to simulate large client cohorts in a resource-aware manner. You can read more about how Flower Simulation works in the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html). The commands shown below assume you have activated your environment (if you decide to use Poetry, you can activate it via `poetry shell`). ### Independent Client/Server Setup @@ -120,10 +120,10 @@ You can also run the example without the scripts. First, launch the server: ``` python server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N ``` -Then run at least two clients (each on a new terminal or computer in your network) passing different `NODE_ID` and all using the same `N` (denoting the total number of clients or data partitions): +Then run at least two clients (each on a new terminal or computer in your network) passing different `PARTITION_ID` and all using the same `N` (denoting the total number of clients or data partitions): ```bash -python client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N +python client.py --train-method=bagging/cyclic --partition-id=PARTITION_ID --num-partitions=N ``` ### Flower Simulation Setup @@ -143,7 +143,7 @@ python sim.py --train-method=cyclic --pool-size=5 --num-rounds=30 --centralised- ``` In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). -Check the [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +Check the [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
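The centralised/distributed evaluation toggle mentioned above is implemented with the conditional-expression pattern visible in the `server.py` and `sim.py` hunks that follow: client-side metric aggregation is only wired up when evaluation stays on the clients. A condensed sketch (the weighted-AUC aggregator is a stand-in for the example's own helper in `utils.py`):

```python
from flwr.server.strategy import FedXgbBagging


def evaluate_metrics_aggregation(eval_metrics):
    # Weighted average of client AUCs (stand-in for the example's helper)
    total = sum(num for num, _ in eval_metrics)
    return {"AUC": sum(m["AUC"] * num for num, m in eval_metrics) / total}


centralised_eval = False  # set by --centralised-eval in the example's arg parsers

strategy = FedXgbBagging(
    fraction_evaluate=1.0 if not centralised_eval else 0.0,
    evaluate_metrics_aggregation_fn=(
        evaluate_metrics_aggregation if not centralised_eval else None
    ),
)
```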
### Expected Experimental Results diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index 74fbc4f5366a..66daed449fd5 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -35,9 +35,9 @@ resplitter=resplit, ) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") if args.centralised_eval: diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml index e6495a98c969..b257801cb420 100644 --- a/examples/xgboost-comprehensive/pyproject.toml +++ b/examples/xgboost-comprehensive/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "xgboost-comprehensive" version = "0.1.0" description = "Federated XGBoost with Flower (comprehensive)" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/xgboost-comprehensive/run_bagging.sh b/examples/xgboost-comprehensive/run_bagging.sh index e853a4ef19cb..a6300b781a06 100755 --- a/examples/xgboost-comprehensive/run_bagging.sh +++ b/examples/xgboost-comprehensive/run_bagging.sh @@ -8,7 +8,7 @@ sleep 30 # Sleep for 30s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --num-partitions=5 --partitioner-type=exponential & + python3 client.py --partition-id=$i --num-partitions=5 --partitioner-type=exponential & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh index 47e09fd8faef..258bdf2fe0d8 100755 --- a/examples/xgboost-comprehensive/run_cyclic.sh +++ b/examples/xgboost-comprehensive/run_cyclic.sh @@ -8,7 +8,7 @@ sleep 15 # Sleep for 15s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & + python3 client.py --partition-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index c6986dc63ce4..2fecbcc65853 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -52,9 +52,9 @@ fraction_evaluate=1.0 if not centralised_eval else 0.0, on_evaluate_config_fn=eval_config, on_fit_config_fn=fit_config, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation - if not centralised_eval - else None, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not centralised_eval else None + ), ) else: # Cyclic training diff --git a/examples/xgboost-comprehensive/sim.py b/examples/xgboost-comprehensive/sim.py index ec05b566dd95..b72b23931929 100644 --- a/examples/xgboost-comprehensive/sim.py +++ b/examples/xgboost-comprehensive/sim.py @@ -98,9 +98,9 @@ def main(): # Load and process all client partitions. This upfront cost is amortized soon # after the simulation begins since clients wont need to preprocess their partition. 
- for node_id in tqdm(range(args.pool_size), desc="Extracting client partition"): - # Extract partition for client with node_id - partition = fds.load_partition(node_id=node_id, split="train") + for partition_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with partition_id + partition = fds.load_partition(partition_id=partition_id, split="train") partition.set_format("numpy") if args.centralised_eval_client: @@ -124,21 +124,21 @@ def main(): if args.train_method == "bagging": # Bagging training strategy = FedXgbBagging( - evaluate_function=get_evaluate_fn(test_dmatrix) - if args.centralised_eval - else None, + evaluate_function=( + get_evaluate_fn(test_dmatrix) if args.centralised_eval else None + ), fraction_fit=(float(args.num_clients_per_round) / args.pool_size), min_fit_clients=args.num_clients_per_round, min_available_clients=args.pool_size, - min_evaluate_clients=args.num_evaluate_clients - if not args.centralised_eval - else 0, + min_evaluate_clients=( + args.num_evaluate_clients if not args.centralised_eval else 0 + ), fraction_evaluate=1.0 if not args.centralised_eval else 0.0, on_evaluate_config_fn=eval_config, on_fit_config_fn=fit_config, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation - if not args.centralised_eval - else None, + evaluate_metrics_aggregation_fn=( + evaluate_metrics_aggregation if not args.centralised_eval else None + ), ) else: # Cyclic training diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 102587f4266d..abc100da1ade 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -37,10 +37,10 @@ def client_args_parser(): help="Partitioner types.", ) parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) parser.add_argument( "--seed", default=42, type=int, help="Seed used for train/test splitting." diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index 5174c236c668..72dde5706e8d 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -67,13 +67,13 @@ To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id=0 +python3 client.py --partition-id=0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id=1 +python3 client.py --partition-id=1 ``` You will see that XGBoost is starting a federated training. @@ -85,4 +85,4 @@ poetry run ./run.sh ``` Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) -and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +and [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index 62e8a441bae1..6ac23ae15148 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -24,13 +24,13 @@ warnings.filterwarnings("ignore", category=UserWarning) -# Define arguments parser for the client/node ID. +# Define arguments parser for the client/partition ID. 
parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() @@ -61,9 +61,9 @@ def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core. partitioner = IidPartitioner(num_partitions=30) fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") # Train/test splitting diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml index af0164514cf1..c16542ea7ffe 100644 --- a/examples/xgboost-quickstart/pyproject.toml +++ b/examples/xgboost-quickstart/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" name = "xgboost-quickstart" version = "0.1.0" description = "Federated XGBoost with Flower (quickstart)" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh index 6287145bfb5f..b35af58222ab 100755 --- a/examples/xgboost-quickstart/run.sh +++ b/examples/xgboost-quickstart/run.sh @@ -8,7 +8,7 @@ sleep 5 # Sleep for 5s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python3 client.py --node-id=$i & + python3 client.py --partition-id=$i & done # Enable CTRL+C to stop all background processes diff --git a/pyproject.toml b/pyproject.toml index 16688268f7be..b52be36b7ff5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,11 @@ name = "flwr" version = "1.8.0" description = "Flower: A Friendly Federated Learning Framework" license = "Apache-2.0" -authors = ["The Flower Authors "] +authors = ["The Flower Authors "] readme = "README.md" -homepage = "https://flower.dev" +homepage = "https://flower.ai" repository = "https://github.com/adap/flower" -documentation = "https://flower.dev" +documentation = "https://flower.ai" keywords = [ "flower", "fl", @@ -52,6 +52,7 @@ exclude = [ ] [tool.poetry.scripts] +flwr = "flwr.cli.app:app" flower-driver-api = "flwr.server:run_driver_api" flower-fleet-api = "flwr.server:run_fleet_api" flower-superlink = "flwr.server:run_superlink" @@ -64,9 +65,10 @@ python = "^3.8" numpy = "^1.21.0" grpcio = "^1.60.0" protobuf = "^4.25.2" -cryptography = "^41.0.2" +cryptography = "^42.0.4" pycryptodome = "^3.18.0" iterators = "^0.0.2" +typer = { version = "^0.9.0", extras=["all"] } # Optional dependencies (VCE) ray = { version = "==2.6.3", optional = true } pydantic = { version = "<2.0.0", optional = true } @@ -86,14 +88,14 @@ types-requests = "==2.31.0.20240125" types-setuptools = "==69.0.0.20240125" clang-format = "==17.0.6" isort = "==5.13.2" -black = { version = "==23.10.1", extras = ["jupyter"] } +black = { version = "==24.2.0", extras = ["jupyter"] } docformatter = "==1.7.5" mypy = "==1.8.0" pylint = "==3.0.3" flake8 = "==5.0.4" pytest = "==7.4.4" pytest-cov = "==4.1.0" -pytest-watch = "==4.2.0" +pytest-watcher = "==0.4.1" grpcio-tools = "==1.60.0" mypy-protobuf = "==3.2.0" jupyterlab = "==4.0.12" @@ -147,6 +149,16 @@ testpaths = [ "src/py/flwr", "src/py/flwr_tool", ] +filterwarnings = 
"ignore::DeprecationWarning" + +[tool.pytest-watcher] +now = false +clear = true +delay = 0.2 +runner = "pytest" +runner_args = ["-s", "-vvvvv"] +patterns = ["*.py"] +ignore_patterns = [] [tool.mypy] plugins = [ diff --git a/src/py/flwr/cli/__init__.py b/src/py/flwr/cli/__init__.py new file mode 100644 index 000000000000..d4d3b8ac4d48 --- /dev/null +++ b/src/py/flwr/cli/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface.""" diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py new file mode 100644 index 000000000000..dc390de03547 --- /dev/null +++ b/src/py/flwr/cli/app.py @@ -0,0 +1,35 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface.""" + +import typer + +from .example import example +from .new import new + +app = typer.Typer( + help=typer.style( + "flwr is the Flower command line interface.", + fg=typer.colors.BRIGHT_YELLOW, + bold=True, + ), + no_args_is_help=True, +) + +app.command()(new) +app.command()(example) + +if __name__ == "__main__": + app() diff --git a/src/py/flwr/cli/example.py b/src/py/flwr/cli/example.py new file mode 100644 index 000000000000..625ca8729640 --- /dev/null +++ b/src/py/flwr/cli/example.py @@ -0,0 +1,64 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `example` command.""" + +import json +import os +import subprocess +import tempfile +import urllib.request + +from .utils import prompt_options + + +def example() -> None: + """Clone a Flower example. 
+ + All examples available in the Flower repository are available through this command. + """ + # Load list of examples directly from GitHub + url = "https://api.github.com/repos/adap/flower/git/trees/main" + with urllib.request.urlopen(url) as res: + data = json.load(res) + examples_directory_url = [ + item["url"] for item in data["tree"] if item["path"] == "examples" + ][0] + + with urllib.request.urlopen(examples_directory_url) as res: + data = json.load(res) + example_names = [ + item["path"] for item in data["tree"] if item["path"] not in [".gitignore"] + ] + + example_name = prompt_options( + "Please select example by typing in the number", + example_names, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + subprocess.check_output( + [ + "git", + "clone", + "--depth=1", + "https://github.com/adap/flower.git", + tmpdirname, + ] + ) + examples_dir = os.path.join(tmpdirname, "examples", example_name) + subprocess.check_output(["mv", examples_dir, "."]) + + print() + print(f"Example ready to use in {os.path.join(os.getcwd(), example_name)}") diff --git a/src/py/flwr/cli/new/__init__.py b/src/py/flwr/cli/new/__init__.py new file mode 100644 index 000000000000..a973f47021c3 --- /dev/null +++ b/src/py/flwr/cli/new/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `new` command.""" + +from .new import new as new + +__all__ = [ + "new", +] diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py new file mode 100644 index 000000000000..d5db6091344d --- /dev/null +++ b/src/py/flwr/cli/new/new.py @@ -0,0 +1,130 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
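The `example` command above resolves the `examples/` directory through GitHub's git/trees API before cloning. A self-contained sketch of just that listing step (same endpoint, unauthenticated, so subject to GitHub's rate limits):

```python
# Sketch: list the entries of `examples/` in adap/flower via the same
# git/trees endpoint used by `flwr example` above.
import json
import urllib.request

url = "https://api.github.com/repos/adap/flower/git/trees/main"
with urllib.request.urlopen(url) as res:
    root_tree = json.load(res)["tree"]

# Locate the sub-tree that describes the `examples` directory
examples_url = next(item["url"] for item in root_tree if item["path"] == "examples")

with urllib.request.urlopen(examples_url) as res:
    names = [item["path"] for item in json.load(res)["tree"]]

print("\n".join(sorted(names)))
```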
+# ============================================================================== +"""Flower command line interface `new` command.""" + +import os +from enum import Enum +from string import Template +from typing import Dict, Optional + +import typer +from typing_extensions import Annotated + +from ..utils import prompt_options + + +class MlFramework(str, Enum): + """Available frameworks.""" + + PYTORCH = "PyTorch" + TENSORFLOW = "TensorFlow" + + +class TemplateNotFound(Exception): + """Raised when template does not exist.""" + + +def load_template(name: str) -> str: + """Load template from template directory and return as text.""" + tpl_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "templates")) + tpl_file_path = os.path.join(tpl_dir, name) + + if not os.path.isfile(tpl_file_path): + raise TemplateNotFound(f"Template '{name}' not found") + + with open(tpl_file_path, encoding="utf-8") as tpl_file: + return tpl_file.read() + + +def render_template(template: str, data: Dict[str, str]) -> str: + """Render template.""" + tpl_file = load_template(template) + tpl = Template(tpl_file) + result = tpl.substitute(data) + return result + + +def create_file(file_path: str, content: str) -> None: + """Create file including all necessary directories and write content into file.""" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + +def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> None: + """Render template and write to file.""" + content = render_template(template, context) + create_file(file_path, content) + + +def new( + project_name: Annotated[ + str, + typer.Argument(metavar="project_name", help="The name of the project"), + ], + framework: Annotated[ + Optional[MlFramework], + typer.Option(case_sensitive=False, help="The ML framework to use"), + ] = None, +) -> None: + """Create new Flower project.""" + print(f"Creating Flower project {project_name}...") + + if framework is not None: + framework_str = str(framework.value) + else: + framework_value = prompt_options( + "Please select ML framework by typing in the number", + [mlf.value for mlf in MlFramework], + ) + selected_value = [ + name + for name, value in vars(MlFramework).items() + if value == framework_value + ] + framework_str = selected_value[0] + + # Set project directory path + cwd = os.getcwd() + pnl = project_name.lower() + project_dir = os.path.join(cwd, pnl) + + # List of files to render + files = { + "README.md": { + "template": "app/README.md.tpl", + }, + "requirements.txt": { + "template": f"app/requirements.{framework_str.lower()}.txt.tpl" + }, + "flower.toml": {"template": "app/flower.toml.tpl"}, + f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"}, + f"{pnl}/server.py": { + "template": f"app/code/server.{framework_str.lower()}.py.tpl" + }, + f"{pnl}/client.py": { + "template": f"app/code/client.{framework_str.lower()}.py.tpl" + }, + } + context = {"project_name": project_name} + + for file_path, value in files.items(): + render_and_create( + file_path=os.path.join(project_dir, file_path), + template=value["template"], + context=context, + ) + + print("Project creation successful.") diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py new file mode 100644 index 000000000000..7a4832013b0c --- /dev/null +++ b/src/py/flwr/cli/new/new_test.py @@ -0,0 +1,99 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
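The rendering machinery in `new.py` above is plain `string.Template` substitution: every `$project_name` placeholder in a `.tpl` file is replaced with the chosen name. A minimal sketch of what `render_template` does, with the template text inlined rather than loaded from the templates directory:

```python
# Sketch of the substitution step behind render_template in new.py.
from string import Template

tpl_text = "# $project_name\n\nFlower project generated by `flwr new`.\n"
result = Template(tpl_text).substitute({"project_name": "FedGPT"})

assert result.startswith("# FedGPT")
print(result)
```

`substitute` raises `KeyError` for any placeholder missing from the mapping, which is why the templates only use keys present in the `context` dict built by `new()`.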
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface `new` command.""" + +import os + +from .new import MlFramework, create_file, load_template, new, render_template + + +def test_load_template() -> None: + """Test if load_template returns a string.""" + # Prepare + filename = "app/README.md.tpl" + + # Execute + text = load_template(filename) + + # Assert + assert isinstance(text, str) + + +def test_render_template() -> None: + """Test if a string is correctly substituted.""" + # Prepare + filename = "app/README.md.tpl" + data = {"project_name": "FedGPT"} + + # Execute + result = render_template(filename, data) + + # Assert + assert "# FedGPT" in result + + +def test_create_file(tmp_path: str) -> None: + """Test if file with content is created.""" + # Prepare + file_path = os.path.join(tmp_path, "test.txt") + content = "Foobar" + + # Execute + create_file(file_path, content) + + # Assert + with open(file_path, encoding="utf-8") as f: + text = f.read() + + assert text == "Foobar" + + +def test_new(tmp_path: str) -> None: + """Test if project is created for framework.""" + # Prepare + project_name = "FedGPT" + framework = MlFramework.PYTORCH + expected_files_top_level = { + "requirements.txt", + "fedgpt", + "README.md", + "flower.toml", + } + expected_files_module = { + "__init__.py", + "server.py", + "client.py", + } + + # Current directory + origin = os.getcwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + + # Execute + new(project_name=project_name, framework=framework) + + # Assert + file_list = os.listdir(os.path.join(tmp_path, project_name.lower())) + assert set(file_list) == expected_files_top_level + + file_list = os.listdir( + os.path.join(tmp_path, project_name.lower(), project_name.lower()) + ) + assert set(file_list) == expected_files_module + finally: + os.chdir(origin) diff --git a/src/py/flwr/cli/new/templates/__init__.py b/src/py/flwr/cli/new/templates/__init__.py new file mode 100644 index 000000000000..7a951c2da1a2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
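The tests above call `new()` directly as a function. Because the command is also registered on the Typer app in `app.py`, the same behavior could be exercised end-to-end through the CLI layer; a sketch using typer's test runner (the `--framework` flag name follows from the `framework` option above, and lowercase `pytorch` is accepted because the enum is declared case-insensitive):

```python
# Sketch: drive `flwr new` through the CLI instead of calling new() directly.
import os

from typer.testing import CliRunner

from flwr.cli.app import app

runner = CliRunner()


def test_new_via_cli(tmp_path: str) -> None:
    origin = os.getcwd()
    try:
        os.chdir(tmp_path)
        result = runner.invoke(app, ["new", "FedGPT", "--framework", "pytorch"])

        assert result.exit_code == 0
        assert os.path.isdir(os.path.join(tmp_path, "fedgpt"))
    finally:
        os.chdir(origin)
```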
+# ============================================================================== +"""Flower CLI `new` command templates.""" diff --git a/src/py/flwr/cli/new/templates/app/README.md.tpl b/src/py/flwr/cli/new/templates/app/README.md.tpl new file mode 100644 index 000000000000..7904fa8d3a3c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/README.md.tpl @@ -0,0 +1,33 @@ +# $project_name + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Start the SuperLink + +```bash +flower-superlink --insecure +``` + +## Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client-app client:app --insecure +``` + +## Start the ServerApp + +```bash +flower-server-app server:app --insecure +``` diff --git a/src/py/flwr/cli/new/templates/app/__init__.py b/src/py/flwr/cli/new/templates/app/__init__.py new file mode 100644 index 000000000000..617628fc9138 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower CLI `new` command app templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py b/src/py/flwr/cli/new/templates/app/code/__init__.py new file mode 100644 index 000000000000..7f1a0e9f4fa2 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower CLI `new` command app / code templates.""" diff --git a/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl new file mode 100644 index 000000000000..57998c81efb8 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/__init__.py.tpl @@ -0,0 +1 @@ +"""$project_name.""" diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl new file mode 100644 index 000000000000..006d00f75e40 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / PyTorch app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl new file mode 100644 index 000000000000..006d00f75e40 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / PyTorch app.""" diff --git a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl new file mode 100644 index 000000000000..cc00f8ff0b8c --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl @@ -0,0 +1 @@ +"""$project_name: A Flower / TensorFlow app.""" diff --git a/src/py/flwr/cli/new/templates/app/flower.toml.tpl b/src/py/flwr/cli/new/templates/app/flower.toml.tpl new file mode 100644 index 000000000000..4dd7117bc3a3 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/flower.toml.tpl @@ -0,0 +1,10 @@ +[flower] +name = "$project_name" +version = "1.0.0" +description = "" +license = "Apache-2.0" +authors = ["The Flower Authors <hello@flower.ai>"] + +[components] +serverapp = "$project_name.server:app" +clientapp = "$project_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl new file mode 100644 index 000000000000..d9426e0b62c0 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl @@ -0,0 +1,4 @@ +flwr>=1.8, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +torch==1.13.1 +torchvision==0.14.1 diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl new file mode 100644 index 000000000000..4fe7bfdc1e89 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl @@ -0,0 +1,4 @@ +flwr>=1.8, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" +tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py new file mode 100644 index 000000000000..d61189ffc4e3 --- /dev/null +++ b/src/py/flwr/cli/utils.py @@ -0,0 +1,54 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
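The `[components]` table rendered above identifies apps with `module:attribute` strings (for example `$project_name.client:app`), the same convention `load_client_app` consumes elsewhere in this patch. A generic sketch of that resolution pattern; `resolve` is an illustrative helper, not Flower's implementation:

```python
# Sketch: resolve a "module:attribute" reference string of the kind
# written into flower.toml by `flwr new`.
import importlib
from typing import Any


def resolve(ref: str) -> Any:
    """Return `attribute` from `module` given a 'module:attribute' string."""
    module_str, separator, attr_str = ref.partition(":")
    if not separator or not module_str or not attr_str:
        raise ValueError(f"Expected 'module:attribute', got '{ref}'")
    module = importlib.import_module(module_str)
    return getattr(module, attr_str)


# Example for a hypothetical project created by `flwr new fedgpt`:
# client_app = resolve("fedgpt.client:app")
```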
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface utils.""" + +from typing import List + +import typer + + +def prompt_options(text: str, options: List[str]) -> str: + """Ask user to select one of the given options and return the selected item.""" + # Turn options into a list with index as in " [ 0] quickstart-pytorch" + options_formatted = [ + " [ " + + typer.style(index, fg=typer.colors.GREEN, bold=True) + + "]" + + f" {typer.style(name, fg=typer.colors.WHITE, bold=True)}" + for index, name in enumerate(options) + ] + + while True: + index = typer.prompt( + "\n" + + typer.style(f"💬 {text}", fg=typer.colors.MAGENTA, bold=True) + + "\n\n" + + "\n".join(options_formatted) + + "\n\n\n" + ) + try: + options[int(index)] # pylint: disable=expression-not-assigned + break + except IndexError: + print(typer.style("❌ Index out of range", fg=typer.colors.RED, bold=True)) + continue + except ValueError: + print( + typer.style("❌ Please choose a number", fg=typer.colors.RED, bold=True) + ) + continue + + result = options[int(index)] + return result diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index f359fb472cbe..a721fb584164 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -19,7 +19,7 @@ from .app import start_client as start_client from .app import start_numpy_client as start_numpy_client from .client import Client as Client -from .clientapp import ClientApp as ClientApp +from .client_app import ClientApp as ClientApp from .numpy_client import NumPyClient as NumPyClient from .typing import ClientFn as ClientFn diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 15f7c5057a20..93d654379cfc 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -23,9 +23,9 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.client.typing import ClientFn -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -34,10 +34,10 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature -from flwr.common.message import Message -from .clientapp import load_client_app +from .client_app import load_client_app from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message @@ -105,7 +105,7 @@ def _load() -> ClientApp: root_certificates=root_certificates, insecure=args.insecure, ) - event(EventType.RUN_CLIENT_APP_LEAVE) + register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) def _parse_args_run_client_app() -> argparse.ArgumentParser: @@ -507,9 +507,7 
@@ def start_numpy_client( ) -def _init_connection( - transport: Optional[str], server_address: str -) -> Tuple[ +def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[ [str, bool, int, Union[bytes, str, None]], ContextManager[ diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 6d982ecc9a9e..23a3755f3efe 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -21,6 +21,7 @@ from flwr.common import ( Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -32,7 +33,6 @@ Parameters, Status, ) -from flwr.common.context import Context class Client(ABC): diff --git a/src/py/flwr/client/clientapp.py b/src/py/flwr/client/client_app.py similarity index 94% rename from src/py/flwr/client/clientapp.py rename to src/py/flwr/client/client_app.py index 51c912890c7e..9de6516c7a39 100644 --- a/src/py/flwr/client/clientapp.py +++ b/src/py/flwr/client/client_app.py @@ -19,12 +19,11 @@ from typing import List, Optional, cast from flwr.client.message_handler.message_handler import ( - handle_legacy_message_from_tasktype, + handle_legacy_message_from_msgtype, ) from flwr.client.mod.utils import make_ffn from flwr.client.typing import ClientFn, Mod -from flwr.common.context import Context -from flwr.common.message import Message +from flwr.common import Context, Message class ClientApp: @@ -63,7 +62,7 @@ def ffn( message: Message, context: Context, ) -> Message: # pylint: disable=invalid-name - out_message = handle_legacy_message_from_tasktype( + out_message = handle_legacy_message_from_msgtype( client_fn=client_fn, message=message, context=context ) return out_message @@ -72,12 +71,12 @@ def ffn( self._call = make_ffn(ffn, mods if mods is not None else []) def __call__(self, message: Message, context: Context) -> Message: - """.""" + """Execute `ClientApp`.""" return self._call(message, context) class LoadClientAppError(Exception): - """.""" + """Error when trying to load `ClientApp`.""" def load_client_app(module_attribute_str: str) -> ClientApp: diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index bd1ea5fab307..3561626dcb39 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -22,20 +22,23 @@ from queue import Queue from typing import Callable, Iterator, Optional, Tuple, Union, cast -from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common import ( + GRPC_MAX_MESSAGE_LENGTH, + ConfigsRecord, + Message, + Metadata, + RecordSet, +) from flwr.common import recordset_compat as compat from flwr.common import serde -from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import ( - TASK_TYPE_EVALUATE, - TASK_TYPE_FIT, - TASK_TYPE_GET_PARAMETERS, - TASK_TYPE_GET_PROPERTIES, + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, + MESSAGE_TYPE_GET_PROPERTIES, ) from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.message import Message, Metadata -from flwr.common.recordset import RecordSet from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -133,33 +136,33 @@ def receive() -> Message: # ServerMessage proto --> *Ins --> RecordSet field = proto.WhichOneof("msg") - task_type = "" + message_type = "" if field == "get_properties_ins": recordset = compat.getpropertiesins_to_recordset( serde.get_properties_ins_from_proto(proto.get_properties_ins) ) - task_type = TASK_TYPE_GET_PROPERTIES + 
message_type = MESSAGE_TYPE_GET_PROPERTIES elif field == "get_parameters_ins": recordset = compat.getparametersins_to_recordset( serde.get_parameters_ins_from_proto(proto.get_parameters_ins) ) - task_type = TASK_TYPE_GET_PARAMETERS + message_type = MESSAGE_TYPE_GET_PARAMETERS elif field == "fit_ins": recordset = compat.fitins_to_recordset( serde.fit_ins_from_proto(proto.fit_ins), False ) - task_type = TASK_TYPE_FIT + message_type = MESSAGE_TYPE_FIT elif field == "evaluate_ins": recordset = compat.evaluateins_to_recordset( serde.evaluate_ins_from_proto(proto.evaluate_ins), False ) - task_type = TASK_TYPE_EVALUATE + message_type = MESSAGE_TYPE_EVALUATE elif field == "reconnect_ins": recordset = RecordSet() - recordset.set_configs( - "config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds}) + recordset.configs_records["config"] = ConfigsRecord( + {"seconds": proto.reconnect_ins.seconds} ) - task_type = "reconnect" + message_type = "reconnect" else: raise ValueError( "Unsupported instruction in ServerMessage, " @@ -170,43 +173,48 @@ def receive() -> Message: return Message( metadata=Metadata( run_id=0, - task_id=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), + src_node_id=0, + dst_node_id=0, + reply_to_message="", group_id="", ttl="", - task_type=task_type, + message_type=message_type, ), - message=recordset, + content=recordset, ) def send(message: Message) -> None: - # Retrieve RecordSet and task_type - recordset = message.message - task_type = message.metadata.task_type + # Retrieve RecordSet and message_type + recordset = message.content + message_type = message.metadata.message_type # RecordSet --> *Res --> *Res proto -> ClientMessage proto - if task_type == TASK_TYPE_GET_PROPERTIES: + if message_type == MESSAGE_TYPE_GET_PROPERTIES: getpropres = compat.recordset_to_getpropertiesres(recordset) msg_proto = ClientMessage( get_properties_res=serde.get_properties_res_to_proto(getpropres) ) - elif task_type == TASK_TYPE_GET_PARAMETERS: + elif message_type == MESSAGE_TYPE_GET_PARAMETERS: getparamres = compat.recordset_to_getparametersres(recordset, False) msg_proto = ClientMessage( get_parameters_res=serde.get_parameters_res_to_proto(getparamres) ) - elif task_type == TASK_TYPE_FIT: + elif message_type == MESSAGE_TYPE_FIT: fitres = compat.recordset_to_fitres(recordset, False) msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres)) - elif task_type == TASK_TYPE_EVALUATE: + elif message_type == MESSAGE_TYPE_EVALUATE: evalres = compat.recordset_to_evaluateres(recordset) msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres)) - elif task_type == "reconnect": - reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"]) + elif message_type == "reconnect": + reason = cast( + Reason.ValueType, recordset.configs_records["config"]["reason"] + ) msg_proto = ClientMessage( disconnect_res=ClientMessage.DisconnectRes(reason=reason) ) else: - raise ValueError(f"Invalid task type: {task_type}") + raise ValueError(f"Invalid message type: {message_type}") # Send ClientMessage proto return queue.put(msg_proto, block=False) diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 54fd41d79325..30bff068b60a 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,11 +23,9 @@ import grpc +from flwr.common import ConfigsRecord, Message, Metadata, RecordSet from flwr.common import recordset_compat as compat -from 
flwr.common.configsrecord import ConfigsRecord -from flwr.common.constant import TASK_TYPE_GET_PROPERTIES -from flwr.common.message import Message, Metadata -from flwr.common.recordset import RecordSet +from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES from flwr.common.typing import Code, GetPropertiesRes, Status from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -46,24 +44,30 @@ MESSAGE_GET_PROPERTIES = Message( metadata=Metadata( run_id=0, - task_id="", + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", group_id="", ttl="", - task_type=TASK_TYPE_GET_PROPERTIES, + message_type=MESSAGE_TYPE_GET_PROPERTIES, ), - message=compat.getpropertiesres_to_recordset( + content=compat.getpropertiesres_to_recordset( GetPropertiesRes(Status(Code.OK, ""), {}) ), ) MESSAGE_DISCONNECT = Message( metadata=Metadata( run_id=0, - task_id="", + message_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", group_id="", ttl="", - task_type="reconnect", + message_type="reconnect", ), - message=RecordSet(configs={"config": ConfigsRecord({"reason": 0})}), + content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), ) @@ -132,7 +136,7 @@ def run_client() -> int: message = receive() messages_received += 1 - if message.metadata.task_type == "reconnect": # type: ignore + if message.metadata.message_type == "reconnect": # type: ignore send(MESSAGE_DISCONNECT) break diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 07635d002721..00b7a864c5d6 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -16,20 +16,17 @@ from contextlib import contextmanager +from copy import copy from logging import DEBUG, ERROR from pathlib import Path from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast -from flwr.client.message_handler.task_handler import ( - configure_task_res, - get_task_ins, - validate_task_ins, - validate_task_res, -) +from flwr.client.message_handler.message_handler import validate_out_message +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature -from flwr.common.message import Message +from flwr.common.message import Message, Metadata from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -42,7 +39,7 @@ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 KEY_NODE = "node" -KEY_TASK_INS = "current_task_ins" +KEY_METADATA = "in_message_metadata" def on_channel_state_change(channel_connectivity: str) -> None: @@ -103,8 +100,8 @@ def grpc_request_response( channel.subscribe(on_channel_state_change) stub = FleetStub(channel) - # Necessary state to link TaskRes to TaskIns - state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} + # Necessary state to validate messages to be sent + state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} # Enable create_node and delete_node to store node node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} @@ -150,14 +147,20 @@ def receive() -> Optional[Message]: task_ins: Optional[TaskIns] = get_task_ins(response) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins(task_ins): + if task_ins is not None 
and not ( + task_ins.task.consumer.node_id == node.node_id + and validate_task_ins(task_ins) + ): task_ins = None - # Remember `task_ins` until `task_res` is available - state[KEY_TASK_INS] = task_ins + # Construct the Message + in_message = message_from_taskins(task_ins) if task_ins else None + + # Remember `metadata` of the in message + state[KEY_METADATA] = copy(in_message.metadata) if in_message else None # Return the message if available - return message_from_taskins(task_ins) if task_ins is not None else None + return in_message def send(message: Message) -> None: """Send task result back to server.""" @@ -165,30 +168,26 @@ def send(message: Message) -> None: if node_store[KEY_NODE] is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) - # Get incoming TaskIns - if state[KEY_TASK_INS] is None: - log(ERROR, "No current TaskIns") + # Get incoming message + in_metadata = state[KEY_METADATA] + if in_metadata is None: + log(ERROR, "No current message") + return + + # Validate out message + if not validate_out_message(message, in_metadata): + log(ERROR, "Invalid out message") return - task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS]) # Construct TaskRes task_res = message_to_taskres(message) - # Check if fields to be set are not initialized - if not validate_task_res(task_res): - state[KEY_TASK_INS] = None - log(ERROR, "TaskRes has been initialized accidentally") - - # Configure TaskRes - task_res = configure_task_res(task_res, task_ins, node) - # Serialize ProtoBuf to bytes request = PushTaskResRequest(task_res_list=[task_res]) _ = stub.PushTaskRes(request) - state[KEY_TASK_INS] = None + state[KEY_METADATA] = None try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 451da089eaf0..87cace88ec27 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,6 +15,7 @@ """Client-side message handler.""" +from logging import WARN from typing import Optional, Tuple, cast from flwr.client.client import ( @@ -23,17 +24,15 @@ maybe_call_get_parameters, maybe_call_get_properties, ) +from flwr.client.numpy_client import NumPyClient from flwr.client.typing import ClientFn -from flwr.common.configsrecord import ConfigsRecord +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log from flwr.common.constant import ( - TASK_TYPE_EVALUATE, - TASK_TYPE_FIT, - TASK_TYPE_GET_PARAMETERS, - TASK_TYPE_GET_PROPERTIES, + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, + MESSAGE_TYPE_GET_PROPERTIES, ) -from flwr.common.context import Context -from flwr.common.message import Message, Metadata -from flwr.common.recordset import RecordSet from flwr.common.recordset_compat import ( evaluateres_to_recordset, fitres_to_recordset, @@ -75,10 +74,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: sleep_duration : int Number of seconds that the client should disconnect from the server. 
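The rewritten `send()` above no longer patches TaskRes fields; it validates the outgoing message against the stored inbound metadata instead. Per the `validate_out_message` rules introduced further down in this patch, a reply mirrors its inbound message: source and destination node IDs are swapped, `reply_to_message` carries the inbound `message_id`, and `run_id`, `group_id`, and `message_type` are preserved. A sketch of that invariant, using the `Metadata` fields as introduced here:

```python
# Sketch of the reply-metadata invariant checked by validate_out_message
# (added below in message_handler.py). Field names as in this patch.
from flwr.common import Metadata


def reply_metadata(in_meta: Metadata) -> Metadata:
    return Metadata(
        run_id=in_meta.run_id,  # same run
        message_id="",  # generated by the server
        src_node_id=in_meta.dst_node_id,  # source and destination swapped
        dst_node_id=in_meta.src_node_id,
        reply_to_message=in_meta.message_id,
        group_id=in_meta.group_id,  # same group
        ttl=in_meta.ttl,  # ttl is not part of the validation
        message_type=in_meta.message_type,
    )
```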
""" - if message.metadata.task_type == "reconnect": + if message.metadata.message_type == "reconnect": # Retrieve ReconnectIns from recordset - recordset = message.message - seconds = cast(int, recordset.get_configs("config")["seconds"]) + recordset = message.content + seconds = cast(int, recordset.configs_records["config"]["seconds"]) # Construct ReconnectIns and call _reconnect disconnect_msg, sleep_duration = _reconnect( ServerMessage.ReconnectIns(seconds=seconds) @@ -86,17 +85,8 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: # Store DisconnectRes in recordset reason = cast(int, disconnect_msg.disconnect_res.reason) recordset = RecordSet() - recordset.set_configs("config", ConfigsRecord({"reason": reason})) - out_message = Message( - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl="", - task_type="reconnect", - ), - message=recordset, - ) + recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) + out_message = message.create_reply(recordset, ttl="") # Return TaskRes and sleep duration return out_message, sleep_duration @@ -104,61 +94,61 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: return None, 0 -def handle_legacy_message_from_tasktype( +def handle_legacy_message_from_msgtype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn("-1") + client = client_fn(str(message.metadata.partition_id)) + + # Check if NumPyClient is returend + if isinstance(client, NumPyClient): + client = client.to_client() + log( + WARN, + "Deprecation Warning: The `client_fn` function must return an instance " + "of `Client`, but an instance of `NumpyClient` was returned. " + "Please use `NumPyClient.to_client()` method to convert it to `Client`.", + ) client.set_context(context) - task_type = message.metadata.task_type + message_type = message.metadata.message_type # Handle GetPropertiesIns - if task_type == TASK_TYPE_GET_PROPERTIES: + if message_type == MESSAGE_TYPE_GET_PROPERTIES: get_properties_res = maybe_call_get_properties( client=client, - get_properties_ins=recordset_to_getpropertiesins(message.message), + get_properties_ins=recordset_to_getpropertiesins(message.content), ) out_recordset = getpropertiesres_to_recordset(get_properties_res) # Handle GetParametersIns - elif task_type == TASK_TYPE_GET_PARAMETERS: + elif message_type == MESSAGE_TYPE_GET_PARAMETERS: get_parameters_res = maybe_call_get_parameters( client=client, - get_parameters_ins=recordset_to_getparametersins(message.message), + get_parameters_ins=recordset_to_getparametersins(message.content), ) out_recordset = getparametersres_to_recordset( get_parameters_res, keep_input=False ) # Handle FitIns - elif task_type == TASK_TYPE_FIT: + elif message_type == MESSAGE_TYPE_FIT: fit_res = maybe_call_fit( client=client, - fit_ins=recordset_to_fitins(message.message, keep_input=True), + fit_ins=recordset_to_fitins(message.content, keep_input=True), ) out_recordset = fitres_to_recordset(fit_res, keep_input=False) # Handle EvaluateIns - elif task_type == TASK_TYPE_EVALUATE: + elif message_type == MESSAGE_TYPE_EVALUATE: evaluate_res = maybe_call_evaluate( client=client, - evaluate_ins=recordset_to_evaluateins(message.message, keep_input=True), + evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True), ) out_recordset = evaluateres_to_recordset(evaluate_res) else: - raise ValueError(f"Invalid task type: {task_type}") + raise ValueError(f"Invalid 
message type: {message_type}") # Return Message - out_message = Message( - metadata=Metadata( - run_id=0, # Non-user defined - task_id="", # Non-user defined - group_id="", # Non-user defined - ttl="", - task_type=task_type, - ), - message=out_recordset, - ) - return out_message + return message.create_reply(out_recordset, ttl="") def _reconnect( @@ -173,3 +163,20 @@ def _reconnect( # Build DisconnectRes message disconnect_res = ClientMessage.DisconnectRes(reason=reason) return ClientMessage(disconnect_res=disconnect_res), sleep_duration + + +def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> bool: + """Validate the out message.""" + out_meta = out_message.metadata + in_meta = in_message_metadata + if ( # pylint: disable-next=too-many-boolean-expressions + out_meta.run_id == in_meta.run_id + and out_meta.message_id == "" # This will be generated by the server + and out_meta.src_node_id == in_meta.dst_node_id + and out_meta.dst_node_id == in_meta.src_node_id + and out_meta.reply_to_message == in_meta.message_id + and out_meta.group_id == in_meta.group_id + and out_meta.message_type == in_meta.message_type + ): + return True + return False diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index ad09ca95abc7..c24b51972f30 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -15,12 +15,16 @@ """Client-side message handler tests.""" +import unittest import uuid +from copy import copy +from typing import List from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( Code, + Context, EvaluateIns, EvaluateRes, FitIns, @@ -29,17 +33,17 @@ GetParametersRes, GetPropertiesIns, GetPropertiesRes, + Message, + Metadata, Parameters, + RecordSet, Status, ) from flwr.common import recordset_compat as compat from flwr.common import typing -from flwr.common.constant import TASK_TYPE_GET_PROPERTIES -from flwr.common.context import Context -from flwr.common.message import Message, Metadata -from flwr.common.recordset import RecordSet +from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES -from .message_handler import handle_legacy_message_from_tasktype +from .message_handler import handle_legacy_message_from_msgtype, validate_out_message class ClientWithoutProps(Client): @@ -121,17 +125,20 @@ def test_client_without_get_properties() -> None: recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) message = Message( metadata=Metadata( - run_id=0, - task_id=str(uuid.uuid4()), - group_id="", + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", ttl="", - task_type=TASK_TYPE_GET_PROPERTIES, + message_type=MESSAGE_TYPE_GET_PROPERTIES, ), - message=recordset, + content=recordset, ) # Execute - actual_msg = handle_legacy_message_from_tasktype( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, context=Context(state=RecordSet()), @@ -146,10 +153,22 @@ def test_client_without_get_properties() -> None: properties={}, ) expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) - expected_msg = Message(message.metadata, expected_rs) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + 
reply_to_message=message.metadata.message_id, + ttl="", + message_type=MESSAGE_TYPE_GET_PROPERTIES, + ), + content=expected_rs, + ) - assert actual_msg.message == expected_msg.message - assert actual_msg.metadata.task_type == expected_msg.metadata.task_type + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata def test_client_with_get_properties() -> None: @@ -159,17 +178,20 @@ def test_client_with_get_properties() -> None: recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) message = Message( metadata=Metadata( - run_id=0, - task_id=str(uuid.uuid4()), - group_id="", + run_id=123, + message_id=str(uuid.uuid4()), + group_id="some group ID", + src_node_id=0, + dst_node_id=1123, + reply_to_message="", ttl="", - task_type=TASK_TYPE_GET_PROPERTIES, + message_type=MESSAGE_TYPE_GET_PROPERTIES, ), - message=recordset, + content=recordset, ) # Execute - actual_msg = handle_legacy_message_from_tasktype( + actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, context=Context(state=RecordSet()), @@ -184,7 +206,85 @@ def test_client_with_get_properties() -> None: properties={"str_prop": "val", "int_prop": 1}, ) expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) - expected_msg = Message(message.metadata, expected_rs) + expected_msg = Message( + metadata=Metadata( + run_id=123, + message_id="", + group_id="some group ID", + src_node_id=1123, + dst_node_id=0, + reply_to_message=message.metadata.message_id, + ttl="", + message_type=MESSAGE_TYPE_GET_PROPERTIES, + ), + content=expected_rs, + ) + + assert actual_msg.content == expected_msg.content + assert actual_msg.metadata == expected_msg.metadata + + +class TestMessageValidation(unittest.TestCase): + """Test message validation.""" + + def setUp(self) -> None: + """Set up the message validation.""" + # Common setup for tests + self.in_metadata = Metadata( + run_id=123, + message_id="qwerty", + src_node_id=10, + dst_node_id=20, + reply_to_message="", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.valid_out_metadata = Metadata( + run_id=123, + message_id="", + src_node_id=20, + dst_node_id=10, + reply_to_message="qwerty", + group_id="group1", + ttl="60", + message_type="mock", + ) + self.common_content = RecordSet() + + def test_valid_message(self) -> None: + """Test a valid message.""" + # Prepare + valid_message = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Assert + self.assertTrue(validate_out_message(valid_message, self.in_metadata)) + + def test_invalid_message_run_id(self) -> None: + """Test invalid messages.""" + # Prepare + msg = Message(metadata=self.valid_out_metadata, content=RecordSet()) + + # Execute + invalid_metadata_list: List[Metadata] = [] + attrs = list(vars(self.valid_out_metadata).keys()) + for attr in attrs: + if attr == "_partition_id": + continue + if attr == "_ttl": # Skip configurable ttl + continue + # Make an invalid metadata + invalid_metadata = copy(self.valid_out_metadata) + value = getattr(invalid_metadata, attr) + if isinstance(value, int): + value = 999 + elif isinstance(value, str): + value = "999" + setattr(invalid_metadata, attr, value) + # Add to list + invalid_metadata_list.append(invalid_metadata) - assert actual_msg.message == expected_msg.message - assert actual_msg.metadata.task_type == expected_msg.metadata.task_type + # Assert + for invalid_metadata in invalid_metadata_list: + msg._metadata = invalid_metadata # pylint: 
disable=protected-access + self.assertFalse(validate_out_message(msg, self.in_metadata)) diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index daac1be77138..7f515a30fe5a 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -18,8 +18,7 @@ from typing import Optional from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 def validate_task_ins(task_ins: TaskIns) -> bool: @@ -41,40 +40,6 @@ def validate_task_ins(task_ins: TaskIns) -> bool: return True -def validate_task_res(task_res: TaskRes) -> bool: - """Validate a TaskRes before filling its fields in the `send()` function. - - Parameters - ---------- - task_res: TaskRes - The task response to be sent to the server. - - Returns - ------- - is_valid: bool - True if the `task_id`, `group_id`, and `run_id` fields in TaskRes - and the `producer`, `consumer`, and `ancestry` fields in its sub-message Task - are not initialized accidentally elsewhere, - False otherwise. - """ - # Retrieve initialized fields in TaskRes and Task - initialized_fields_in_task_res = {field.name for field, _ in task_res.ListFields()} - initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} - - # Check if certain fields are already initialized - if ( # pylint: disable-next=too-many-boolean-expressions - "task_id" in initialized_fields_in_task_res - or "group_id" in initialized_fields_in_task_res - or "run_id" in initialized_fields_in_task_res - or "producer" in initialized_fields_in_task - or "consumer" in initialized_fields_in_task - or "ancestry" in initialized_fields_in_task - ): - return False - - return True - - def get_task_ins( pull_task_ins_response: PullTaskInsResponse, ) -> Optional[TaskIns]: @@ -87,35 +52,3 @@ def get_task_ins( task_ins: TaskIns = pull_task_ins_response.task_ins_list[0] return task_ins - - -def configure_task_res( - task_res: TaskRes, ref_task_ins: TaskIns, producer: Node -) -> TaskRes: - """Set the metadata of a TaskRes. - - Fill `group_id` and `run_id` in TaskRes - and `producer`, `consumer`, and `ancestry` in Task in TaskRes. - - `producer` in Task in TaskRes will remain unchanged/unset. - - Note that protobuf API `protobuf.message.MergeFrom(other_msg)` - does NOT always overwrite fields that are set in `other_msg`. 
- Please refer to: - https://googleapis.dev/python/protobuf/latest/google/protobuf/message.html - """ - task_res = TaskRes( - task_id="", # This will be generated by the server - group_id=ref_task_ins.group_id, - run_id=ref_task_ins.run_id, - task=task_res.task, - ) - # pylint: disable-next=no-member - task_res.task.MergeFrom( - Task( - producer=producer, - consumer=ref_task_ins.task.producer, - ancestry=[ref_task_ins.task_id], - ) - ) - return task_res diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 9a668231d509..c8b9e14737ff 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -15,15 +15,10 @@ """Tests for module task_handler.""" -from flwr.client.message_handler.task_handler import ( - get_task_ins, - validate_task_ins, - validate_task_res, -) -from flwr.common import serde -from flwr.common.recordset import RecordSet +from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins +from flwr.common import RecordSet, serde from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 def test_validate_task_ins_no_task() -> None: @@ -47,38 +42,6 @@ def test_validate_task_ins_valid() -> None: assert validate_task_ins(task_ins) -def test_validate_task_res() -> None: - """Test validate_task_res.""" - task_res = TaskRes(task=Task()) - assert validate_task_res(task_res) - - task_res.task_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.group_id = "123" - assert not validate_task_res(task_res) - - task_res.Clear() - task_res.run_id = 61016 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.producer.node_id = 0 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.consumer.node_id = 0 - assert not validate_task_res(task_res) - - task_res.Clear() - # pylint: disable-next=no-member - task_res.task.ancestry.append("123") - assert not validate_task_res(task_res) - - def test_get_task_ins_empty_response() -> None: """Test get_task_ins.""" res = PullTaskInsResponse(reconnect=None, task_ins_list=[]) diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py index a181865614df..e06cabe995d7 100644 --- a/src/py/flwr/client/mod/__init__.py +++ b/src/py/flwr/client/mod/__init__.py @@ -15,10 +15,13 @@ """Mods.""" +from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod from .secure_aggregation.secaggplus_mod import secaggplus_mod from .utils import make_ffn __all__ = [ + "adaptiveclipping_mod", + "fixedclipping_mod", "make_ffn", "secaggplus_mod", ] diff --git a/src/py/flwr/client/mod/centraldp_mods.py b/src/py/flwr/client/mod/centraldp_mods.py new file mode 100644 index 000000000000..8a18e87c69f5 --- /dev/null +++ b/src/py/flwr/client/mod/centraldp_mods.py @@ -0,0 +1,137 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
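The new `centraldp_mods.py` module, whose body follows, delegates the clipping itself to `compute_clip_model_update` from `flwr.common.differential_privacy`. Under the standard definition of fixed L2 clipping, that step scales the model delta by `min(1, clipping_norm / ||delta||_2)`; a NumPy sketch of that operation (an illustration of the technique, not the library's actual implementation):

```python
# Sketch: in-place fixed L2 clipping of a client model update.
from typing import List

import numpy as np


def clip_model_update(
    params: List[np.ndarray],
    starting_params: List[np.ndarray],
    clipping_norm: float,
) -> None:
    # Model delta: trained parameters minus the parameters received
    update = [p - q for p, q in zip(params, starting_params)]

    # Global L2 norm of the delta across all layers
    norm = float(np.sqrt(sum(np.sum(np.square(u)) for u in update)))

    # Scale so the clipped delta has norm at most `clipping_norm`
    scale = min(1.0, clipping_norm / norm) if norm > 0 else 1.0

    # Write back: starting point plus the clipped delta
    for p, q, u in zip(params, starting_params, update):
        np.copyto(p, q + u * scale)
```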
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Clipping modifiers for central DP with client-side clipping.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MESSAGE_TYPE_FIT +from flwr.common.context import Context +from flwr.common.differential_privacy import ( + compute_adaptive_clip_model_update, + compute_clip_model_update, +) +from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT +from flwr.common.message import Message + + +def fixedclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side fixed clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It operates on messages with type MESSAGE_TYPE_FIT. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, fixedclipping_mod should be the last to operate on params. + """ + if msg.metadata.message_type != MESSAGE_TYPE_FIT: + return call_next(msg, ctxt) + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." + ) + + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg + + +def adaptiveclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side adaptive clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It also sends KEY_NORM_BIT to the server for computing the new clipping value. + + It operates on messages with type MESSAGE_TYPE_FIT. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, adaptiveclipping_mod should be the last to operate on params. 
+ """ + if msg.metadata.message_type != MESSAGE_TYPE_FIT: + return call_next(msg, ctxt) + + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." + ) + if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float): + raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.") + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + norm_bit = compute_adaptive_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + + fit_res.metrics[KEY_NORM_BIT] = norm_bit + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 850e02a2b5f9..fcb30c1eb0da 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -21,14 +21,17 @@ from typing import Any, Callable, Dict, List, Tuple, cast from flwr.client.typing import ClientAppCallable -from flwr.common import ndarray_to_bytes, parameters_to_ndarrays +from flwr.common import ( + ConfigsRecord, + Context, + Message, + RecordSet, + ndarray_to_bytes, + parameters_to_ndarrays, +) from flwr.common import recordset_compat as compat -from flwr.common.configsrecord import ConfigsRecord -from flwr.common.constant import TASK_TYPE_FIT -from flwr.common.context import Context +from flwr.common.constant import MESSAGE_TYPE_FIT from flwr.common.logger import log -from flwr.common.message import Message, Metadata -from flwr.common.recordset import RecordSet from flwr.common.secure_aggregation.crypto.shamir import create_shares from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_private_key, @@ -156,7 +159,7 @@ def _get_fit_fn( def fit() -> FitRes: out_msg = call_next(msg, ctxt) - return compat.recordset_to_fitres(out_msg.message, keep_input=False) + return compat.recordset_to_fitres(out_msg.content, keep_input=False) return fit @@ -168,17 +171,17 @@ def secaggplus_mod( ) -> Message: """Handle incoming message and return results, following the SecAgg+ protocol.""" # Ignore non-fit messages - if msg.metadata.task_type != TASK_TYPE_FIT: + if msg.metadata.message_type != MESSAGE_TYPE_FIT: return call_next(msg, ctxt) # Retrieve local state - if RECORD_KEY_STATE not in ctxt.state.configs: - ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({})) - state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data + if RECORD_KEY_STATE not in ctxt.state.configs_records: + ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({}) + state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] state = SecAggPlusState(**state_dict) # Retrieve incoming configs - configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data + configs = msg.content.configs_records[RECORD_KEY_CONFIGS] # Check the validity of the next stage check_stage(state.current_stage, configs) @@ -203,16 
+206,14 @@ def secaggplus_mod(
         raise ValueError(f"Unknown secagg stage: {state.current_stage}")
 
     # Save state
-    ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict()))
+    ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
 
     # Return message
-    return Message(
-        metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
-        message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}),
-    )
+    content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
+    return msg.create_reply(content, ttl="")
 
 
-def check_stage(current_stage: str, configs: Dict[str, ConfigsRecordValues]) -> None:
+def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
     """Check the validity of the next stage."""
     # Check the existence of KEY_STAGE
     if KEY_STAGE not in configs:
@@ -244,7 +245,7 @@ def check_stage(current_stage: str, configs: Dict[str, ConfigsRecordValues]) ->
 
 
 # pylint: disable-next=too-many-branches
-def check_configs(stage: str, configs: Dict[str, ConfigsRecordValues]) -> None:
+def check_configs(stage: str, configs: ConfigsRecord) -> None:
     """Check the validity of the configs."""
     # Check `named_values` for the setup stage
     if stage == STAGE_SETUP:
@@ -335,7 +336,7 @@ def check_configs(stage: str, configs: Dict[str, ConfigsRecordValues]) -> None:
 
 
 def _setup(
-    state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues]
+    state: SecAggPlusState, configs: ConfigsRecord
 ) -> Dict[str, ConfigsRecordValues]:
     # Assigning parameter values to object fields
     sec_agg_param_dict = configs
@@ -371,7 +372,7 @@ def _setup(
 
 # pylint: disable-next=too-many-locals
 def _share_keys(
-    state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues]
+    state: SecAggPlusState, configs: ConfigsRecord
 ) -> Dict[str, ConfigsRecordValues]:
     named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
     key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
@@ -434,7 +435,7 @@ def _share_keys(
 
 # pylint: disable-next=too-many-locals
 def _collect_masked_input(
     state: SecAggPlusState,
-    configs: Dict[str, ConfigsRecordValues],
+    configs: ConfigsRecord,
     fit: Callable[[], FitRes],
 ) -> Dict[str, ConfigsRecordValues]:
     log(INFO, "Client %d: starting stage 2...", state.sid)
@@ -509,7 +510,7 @@ def _collect_masked_input(
 
 
 def _unmask(
-    state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues]
+    state: SecAggPlusState, configs: ConfigsRecord
 ) -> Dict[str, ConfigsRecordValues]:
     log(INFO, "Client %d: starting stage 3...", state.sid)
diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
index 80d607318651..0a181f2ea8f5 100644
--- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
+++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
@@ -19,11 +19,8 @@
 from typing import Callable, Dict, List
 
 from flwr.client.mod import make_ffn
-from flwr.common.configsrecord import ConfigsRecord
-from flwr.common.constant import TASK_TYPE_FIT
-from flwr.common.context import Context
-from flwr.common.message import Message, Metadata
-from flwr.common.recordset import RecordSet
+from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet
+from flwr.common.constant import MESSAGE_TYPE_FIT
 from flwr.common.secure_aggregation.secaggplus_constants import (
     KEY_ACTIVE_SECURE_ID_LIST,
     KEY_CIPHERTEXT_LIST,
@@ -52,41 +49,49 @@
 
 def get_test_handler(
     ctxt: Context,
-) -> Callable[[Dict[str, ConfigsRecordValues]], Dict[str, ConfigsRecordValues]]:
+) -> Callable[[Dict[str, ConfigsRecordValues]], ConfigsRecord]:
     """."""
 
-    def empty_ffn(_: Message, _2: Context) -> Message:
-        return Message(
-            metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
-            message=RecordSet(),
-        )
+    def empty_ffn(_msg: Message, _2: Context) -> Message:
+        return _msg.create_reply(RecordSet(), ttl="")
 
     app = make_ffn(empty_ffn, [secaggplus_mod])
 
-    def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]:
+    def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:
         in_msg = Message(
-            metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
-            message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}),
+            metadata=Metadata(
+                run_id=0,
+                message_id="",
+                src_node_id=0,
+                dst_node_id=123,
+                reply_to_message="",
+                group_id="",
+                ttl="",
+                message_type=MESSAGE_TYPE_FIT,
+            ),
+            content=RecordSet(
+                configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}
+            ),
         )
         out_msg = app(in_msg, ctxt)
-        return out_msg.message.get_configs(RECORD_KEY_CONFIGS).data
+        return out_msg.content.configs_records[RECORD_KEY_CONFIGS]
 
     return func
 
 
 def _make_ctxt() -> Context:
     cfg = ConfigsRecord(SecAggPlusState().to_dict())
-    return Context(RecordSet(configs={RECORD_KEY_STATE: cfg}))
+    return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}))
 
 
 def _make_set_state_fn(
     ctxt: Context,
 ) -> Callable[[str], None]:
     def set_stage(stage: str) -> None:
-        state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data
+        state_dict = ctxt.state.configs_records[RECORD_KEY_STATE]
         state = SecAggPlusState(**state_dict)
         state.current_stage = stage
-        ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict()))
+        ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
 
     return set_stage
@@ -170,7 +175,7 @@ def test_stage_setup_check(self) -> None:
 
         # Test valid `named_values`
         try:
-            check_configs(STAGE_SETUP, valid_configs.copy())
+            check_configs(STAGE_SETUP, ConfigsRecord(valid_configs))
         # pylint: disable-next=broad-except
         except Exception as exc:
             self.fail(f"check_named_values() raised {type(exc)} unexpectedly!")
@@ -213,7 +218,7 @@ def test_stage_share_keys_check(self) -> None:
 
         # Test valid `named_values`
         try:
-            check_configs(STAGE_SHARE_KEYS, valid_configs.copy())
+            check_configs(STAGE_SHARE_KEYS, ConfigsRecord(valid_configs))
         # pylint: disable-next=broad-except
         except Exception as exc:
             self.fail(f"check_named_values() raised {type(exc)} unexpectedly!")
@@ -249,7 +254,7 @@ def test_stage_collect_masked_input_check(self) -> None:
 
         # Test valid `named_values`
         try:
-            check_configs(STAGE_COLLECT_MASKED_INPUT, valid_configs.copy())
+            check_configs(STAGE_COLLECT_MASKED_INPUT, ConfigsRecord(valid_configs))
         # pylint: disable-next=broad-except
         except Exception as exc:
             self.fail(f"check_named_values() raised {type(exc)} unexpectedly!")
@@ -293,7 +298,7 @@ def test_stage_unmask_check(self) -> None:
 
         # Test valid `named_values`
         try:
-            check_configs(STAGE_UNMASK, valid_configs.copy())
+            check_configs(STAGE_UNMASK, ConfigsRecord(valid_configs))
         # pylint: disable-next=broad-except
         except Exception as exc:
             self.fail(f"check_named_values() raised {type(exc)} unexpectedly!")
diff --git a/src/py/flwr/client/mod/utils.py b/src/py/flwr/client/mod/utils.py
index 3db5da563c23..4c3c32944f01 100644
--- a/src/py/flwr/client/mod/utils.py
+++ b/src/py/flwr/client/mod/utils.py
@@ -18,8 +18,7 @@
 from typing import List
 
 from flwr.client.typing import ClientAppCallable, Mod
-from flwr.common.context import Context
-from flwr.common.message import Message
+from flwr.common import Context, Message
 
 
 def make_ffn(ffn: ClientAppCallable, mods: List[Mod]) -> ClientAppCallable:
diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py
index 782ca8b0f97e..e588b8b53b3b 100644
--- a/src/py/flwr/client/mod/utils_test.py
+++ b/src/py/flwr/client/mod/utils_test.py
@@ -16,14 +16,17 @@
 
 import unittest
-from typing import List
+from typing import List, cast
 
 from flwr.client.typing import ClientAppCallable, Mod
-from flwr.common.configsrecord import ConfigsRecord
-from flwr.common.context import Context
-from flwr.common.message import Message, Metadata
-from flwr.common.metricsrecord import MetricsRecord
-from flwr.common.recordset import RecordSet
+from flwr.common import (
+    ConfigsRecord,
+    Context,
+    Message,
+    Metadata,
+    MetricsRecord,
+    RecordSet,
+)
 
 from .utils import make_ffn
@@ -33,10 +36,10 @@
 
 def _increment_context_counter(context: Context) -> None:
     # Read from context
-    current_counter: int = context.state.get_metrics(METRIC)[COUNTER]  # type: ignore
+    current_counter = cast(int, context.state.metrics_records[METRIC][COUNTER])
     # update and override context
     current_counter += 1
-    context.state.set_metrics(METRIC, record=MetricsRecord({COUNTER: current_counter}))
+    context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter})
 
 
 def make_mock_mod(name: str, footprint: List[str]) -> Mod:
@@ -45,13 +48,13 @@
     def mod(message: Message, context: Context, app: ClientAppCallable) -> Message:
         footprint.append(name)
         # add empty ConfigRecord to in_message for this mod
-        message.message.set_configs(name=name, record=ConfigsRecord())
+        message.content.configs_records[name] = ConfigsRecord()
         _increment_context_counter(context)
         out_message: Message = app(message, context)
         footprint.append(name)
         _increment_context_counter(context)
         # add empty ConfigRegcord to out_message for this mod
-        out_message.message.set_configs(name=name, record=ConfigsRecord())
+        out_message.content.configs_records[name] = ConfigsRecord()
         return out_message
 
     return mod
@@ -62,9 +65,9 @@ def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable:
 
     def app(message: Message, context: Context) -> Message:
         footprint.append(name)
-        message.message.set_configs(name=name, record=ConfigsRecord())
-        out_message = Message(metadata=message.metadata, message=RecordSet())
-        out_message.message.set_configs(name=name, record=ConfigsRecord())
+        message.content.configs_records[name] = ConfigsRecord()
+        out_message = Message(metadata=message.metadata, content=RecordSet())
+        out_message.content.configs_records[name] = ConfigsRecord()
         print(context)
         return out_message
 
@@ -73,8 +76,17 @@
 
 def _get_dummy_flower_message() -> Message:
     return Message(
-        message=RecordSet(),
-        metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"),
+        content=RecordSet(),
+        metadata=Metadata(
+            run_id=0,
+            message_id="",
+            group_id="",
+            src_node_id=0,
+            dst_node_id=0,
+            reply_to_message="",
+            ttl="",
+            message_type="mock",
+        ),
     )
 
@@ -90,7 +102,7 @@
         mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names]
 
         state = RecordSet()
-        state.set_metrics(METRIC, record=MetricsRecord({COUNTER: 0.0}))
+        state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0})
         context = Context(state=state)
         message = _get_dummy_flower_message()
@@ -102,11 +114,14 @@ def test_multiple_mods(self) -> None:
         trace = mock_mod_names + ["app"]
         self.assertEqual(footprint, trace + list(reversed(mock_mod_names)))
         # pylint: disable-next=no-member
-        self.assertEqual("".join(message.message.configs.keys()), "".join(trace))
         self.assertEqual(
-            "".join(out_message.message.configs.keys()), "".join(reversed(trace))
+            "".join(message.content.configs_records.keys()), "".join(trace)
         )
-        self.assertEqual(state.get_metrics(METRIC)[COUNTER], 2 * len(mock_mods))
+        self.assertEqual(
+            "".join(out_message.content.configs_records.keys()),
+            "".join(reversed(trace)),
+        )
+        self.assertEqual(state.metrics_records[METRIC][COUNTER], 2 * len(mock_mods))
 
     def test_filter(self) -> None:
         """Test if a mod can filter incoming TaskIns."""
@@ -122,9 +137,9 @@ def filter_mod(
             _2: ClientAppCallable,
         ) -> Message:
             footprint.append("filter")
-            message.message.set_configs(name="filter", record=ConfigsRecord())
-            out_message = Message(metadata=message.metadata, message=RecordSet())
-            out_message.message.set_configs(name="filter", record=ConfigsRecord())
+            message.content.configs_records["filter"] = ConfigsRecord()
+            out_message = Message(metadata=message.metadata, content=RecordSet())
+            out_message.content.configs_records["filter"] = ConfigsRecord()
             # Skip calling app
             return out_message
 
@@ -135,5 +150,5 @@ def filter_mod(
         # Assert
         self.assertEqual(footprint, ["filter"])
         # pylint: disable-next=no-member
-        self.assertEqual(list(message.message.configs.keys())[0], "filter")
-        self.assertEqual(list(out_message.message.configs.keys())[0], "filter")
+        self.assertEqual(list(message.content.configs_records.keys())[0], "filter")
+        self.assertEqual(list(out_message.content.configs_records.keys())[0], "filter")
diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py
index 465bbd356c1c..71681b783419 100644
--- a/src/py/flwr/client/node_state.py
+++ b/src/py/flwr/client/node_state.py
@@ -17,8 +17,7 @@
 
 from typing import Any, Dict
 
-from flwr.common.context import Context
-from flwr.common.recordset import RecordSet
+from flwr.common import Context, RecordSet
 
 
 class NodeState:
diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py
index 11e5e74a31ec..193f52661579 100644
--- a/src/py/flwr/client/node_state_tests.py
+++ b/src/py/flwr/client/node_state_tests.py
@@ -15,21 +15,20 @@
 """Node state tests."""
 
+from typing import cast
+
 from flwr.client.node_state import NodeState
-from flwr.common.configsrecord import ConfigsRecord
-from flwr.common.context import Context
+from flwr.common import ConfigsRecord, Context
 from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
 
 
 def _run_dummy_task(context: Context) -> Context:
     counter_value: str = "1"
-    if "counter" in context.state.configs.keys():
-        counter_value = context.get_configs("counter")["count"]  # type: ignore
+    if "counter" in context.state.configs_records.keys():
+        counter_value = cast(str, context.state.configs_records["counter"]["count"])
         counter_value += "1"
 
-    context.state.set_configs(
-        name="counter", record=ConfigsRecord({"count": counter_value})
-    )
+    context.state.configs_records["counter"] = ConfigsRecord({"count": counter_value})
 
     return context
@@ -61,4 +60,6 @@ def test_multirun_in_node_state() -> None:
 
     # Verify values
     for run_id, context in node_state.run_contexts.items():
-        assert context.state.get_configs("counter")["count"] == expected_values[run_id]
+        assert (
+            context.state.configs_records["counter"]["count"] == expected_values[run_id]
+        )
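
A minimal sketch (illustrative, not part of this patch) of a mod written against the new `Message`/`RecordSet` API exercised above; `counter_mod` and the record key are invented names:

from flwr.client.typing import ClientAppCallable
from flwr.common import ConfigsRecord, Context, Message


def counter_mod(msg: Message, ctxt: Context, call_next: ClientAppCallable) -> Message:
    """Count handled messages in the per-node context."""
    calls = 0
    if "counter-mod" in ctxt.state.configs_records.keys():
        calls = int(ctxt.state.configs_records["counter-mod"]["calls"])  # type: ignore
    # Plain dict-style assignment replaces the old set_configs() call
    ctxt.state.configs_records["counter-mod"] = ConfigsRecord({"calls": calls + 1})
    # Delegate to the next mod (or the app) and return its reply unchanged
    return call_next(msg, ctxt)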
diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py
index a77889912a09..0247958d88a9 100644
--- a/src/py/flwr/client/numpy_client.py
+++ b/src/py/flwr/client/numpy_client.py
@@ -21,12 +21,12 @@
 from flwr.client.client import Client
 from flwr.common import (
     Config,
+    Context,
     NDArrays,
     Scalar,
     ndarrays_to_parameters,
     parameters_to_ndarrays,
 )
-from flwr.common.context import Context
 from flwr.common.typing import (
     Code,
     EvaluateIns,
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index a5c8ea0957d2..c637475551ed 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -17,19 +17,16 @@
 
 import sys
 from contextlib import contextmanager
+from copy import copy
 from logging import ERROR, INFO, WARN
 from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast
 
-from flwr.client.message_handler.task_handler import (
-    configure_task_res,
-    get_task_ins,
-    validate_task_ins,
-    validate_task_res,
-)
+from flwr.client.message_handler.message_handler import validate_out_message
+from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
 from flwr.common import GRPC_MAX_MESSAGE_LENGTH
 from flwr.common.constant import MISSING_EXTRA_REST
 from flwr.common.logger import log
-from flwr.common.message import Message
+from flwr.common.message import Message, Metadata
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
@@ -50,7 +47,7 @@
 
 KEY_NODE = "node"
-KEY_TASK_INS = "current_task_ins"
+KEY_METADATA = "in_message_metadata"
 
 PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
@@ -122,8 +119,8 @@
             "must be provided as a string path to the client.",
         )
 
-    # Necessary state to link TaskRes to TaskIns
-    state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None}
+    # Necessary state to validate messages to be sent
+    state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}
 
     # Enable create_node and delete_node to store node
     node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
@@ -258,16 +255,18 @@ def receive() -> Optional[Message]:
         task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto)
 
         # Discard the current TaskIns if not valid
-        if task_ins is not None and not validate_task_ins(task_ins):
+        if task_ins is not None and not (
+            task_ins.task.consumer.node_id == node.node_id
+            and validate_task_ins(task_ins)
+        ):
             task_ins = None
 
-        # Remember `task_ins` until `task_res` is available
-        state[KEY_TASK_INS] = task_ins
-
         # Return the Message if available
         message = None
+        state[KEY_METADATA] = None
         if task_ins is not None:
             message = message_from_taskins(task_ins)
+            state[KEY_METADATA] = copy(message.metadata)
             log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS)
 
         return message
@@ -277,25 +276,21 @@ def send(message: Message) -> None:
         if node_store[KEY_NODE] is None:
             log(ERROR, "Node instance missing")
             return
-        node: Node = cast(Node, node_store[KEY_NODE])
 
-        if state[KEY_TASK_INS] is None:
-            log(ERROR, "No current TaskIns")
+        # Get incoming message
+        in_metadata = state[KEY_METADATA]
+        if in_metadata is None:
+            log(ERROR, "No current message")
             return
 
-        task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS])
+        # Validate out message
+        if not validate_out_message(message, in_metadata):
+            log(ERROR, "Invalid out message")
+            return
 
         # Construct TaskRes
         task_res = message_to_taskres(message)
 
-        # Check if fields to be set are not initialized
-        if not validate_task_res(task_res):
-            state[KEY_TASK_INS] = None
-            log(ERROR, "TaskRes has been initialized accidentally")
-
-        # Configure TaskRes
-        task_res = configure_task_res(task_res, task_ins, node)
-
         # Serialize ProtoBuf to bytes
         push_task_res_request_proto = PushTaskResRequest(task_res_list=[task_res])
         push_task_res_request_bytes: bytes = (
@@ -314,7 +309,7 @@ def send(message: Message) -> None:
             timeout=None,
         )
 
-        state[KEY_TASK_INS] = None
+        state[KEY_METADATA] = None
 
         # Check status code and headers
         if res.status_code != 200:
diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py
index 7aef2b30e0fc..956ac7a15c05 100644
--- a/src/py/flwr/client/typing.py
+++ b/src/py/flwr/client/typing.py
@@ -17,8 +17,7 @@
 
 from typing import Callable
 
-from flwr.common.context import Context
-from flwr.common.message import Message
+from flwr.common import Context, Message
 
 from .client import Client as Client
diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py
index 2f45de45dfc3..319c15900217 100644
--- a/src/py/flwr/common/__init__.py
+++ b/src/py/flwr/common/__init__.py
@@ -15,14 +15,23 @@
 """Common components shared between server and client."""
 
+from .context import Context as Context
 from .date import now as now
 from .grpc import GRPC_MAX_MESSAGE_LENGTH
 from .logger import configure as configure
 from .logger import log as log
+from .message import Message as Message
+from .message import Metadata as Metadata
 from .parameter import bytes_to_ndarray as bytes_to_ndarray
 from .parameter import ndarray_to_bytes as ndarray_to_bytes
 from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
 from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
+from .record import Array as Array
+from .record import ConfigsRecord as ConfigsRecord
+from .record import MetricsRecord as MetricsRecord
+from .record import ParametersRecord as ParametersRecord
+from .record import RecordSet as RecordSet
+from .record import array_from_numpy as array_from_numpy
 from .telemetry import EventType as EventType
 from .telemetry import event as event
 from .typing import ClientMessage as ClientMessage
@@ -49,11 +58,15 @@
 from .typing import Status as Status
 
 __all__ = [
+    "Array",
+    "array_from_numpy",
     "bytes_to_ndarray",
     "ClientMessage",
     "Code",
     "Config",
+    "ConfigsRecord",
     "configure",
+    "Context",
     "DisconnectRes",
     "EvaluateIns",
     "EvaluateRes",
@@ -67,8 +80,11 @@
     "GetPropertiesRes",
     "GRPC_MAX_MESSAGE_LENGTH",
     "log",
+    "Message",
+    "Metadata",
     "Metrics",
     "MetricsAggregationFn",
+    "MetricsRecord",
     "ndarray_to_bytes",
     "now",
     "NDArray",
@@ -76,8 +92,10 @@
     "ndarrays_to_parameters",
     "Parameters",
     "parameters_to_ndarrays",
+    "ParametersRecord",
     "Properties",
     "ReconnectIns",
+    "RecordSet",
     "Scalar",
     "ServerMessage",
     "Status",
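
With the re-exports above in place, downstream code can import the new primitives straight from `flwr.common` instead of reaching into private submodules. A short sketch of the intended import style (names taken from the `__all__` list above):

from flwr.common import (
    ConfigsRecord,
    Context,
    Message,
    Metadata,
    MetricsRecord,
    ParametersRecord,
    RecordSet,
    array_from_numpy,
)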
diff --git a/src/py/flwr/common/address_test.py b/src/py/flwr/common/address_test.py
index c12dd5fd289e..420b89871d69 100644
--- a/src/py/flwr/common/address_test.py
+++ b/src/py/flwr/common/address_test.py
@@ -109,10 +109,10 @@ def test_domain_correct() -> None:
     """Test if a correct domain address is correctly parsed."""
     # Prepare
     addresses = [
-        ("flower.dev:123", ("flower.dev", 123, None)),
-        ("sub.flower.dev:123", ("sub.flower.dev", 123, None)),
-        ("sub2.sub1.flower.dev:123", ("sub2.sub1.flower.dev", 123, None)),
-        ("s5.s4.s3.s2.s1.flower.dev:123", ("s5.s4.s3.s2.s1.flower.dev", 123, None)),
+        ("flower.ai:123", ("flower.ai", 123, None)),
+        ("sub.flower.ai:123", ("sub.flower.ai", 123, None)),
+        ("sub2.sub1.flower.ai:123", ("sub2.sub1.flower.ai", 123, None)),
+        ("s5.s4.s3.s2.s1.flower.ai:123", ("s5.s4.s3.s2.s1.flower.ai", 123, None)),
         ("localhost:123", ("localhost", 123, None)),
         ("https://localhost:123", ("https://localhost", 123, None)),
         ("http://localhost:123", ("http://localhost", 123, None)),
@@ -130,8 +130,8 @@ def test_domain_incorrect() -> None:
     """Test if an incorrect domain address returns None."""
     # Prepare
     addresses = [
-        "flower.dev",
-        "flower.dev:65536",
+        "flower.ai",
+        "flower.ai:65536",
     ]
 
     for address in addresses:
diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py
deleted file mode 100644
index b0480841e06c..000000000000
--- a/src/py/flwr/common/configsrecord.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""ConfigsRecord."""
-
-
-from dataclasses import dataclass, field
-from typing import Dict, Optional, get_args
-
-from .typing import ConfigsRecordValues, ConfigsScalar
-
-
-@dataclass
-class ConfigsRecord:
-    """Configs record."""
-
-    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)
-
-    def __init__(
-        self,
-        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
-        keep_input: bool = True,
-    ):
-        """Construct a ConfigsRecord object.
-
-        Parameters
-        ----------
-        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
-            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
-            defined in `ConfigsScalar`) and lists of such types (see
-            `ConfigsScalarList`).
-        keep_input : bool (default: True)
-            A boolean indicating whether config passed should be deleted from the input
-            dictionary immediately after adding them to the record. When set
-            to True, the data is duplicated in memory. If memory is a concern, set
-            it to False.
-        """
-        self.data = {}
-        if configs_dict:
-            self.set_configs(configs_dict, keep_input=keep_input)
-
-    def set_configs(
-        self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
-    ) -> None:
-        """Add configs to the record.
-
-        Parameters
-        ----------
-        configs_dict : Dict[str, ConfigsRecordValues]
-            A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
-            defined in `ConfigsRecordValues`) and list of such types (see
-            `ConfigsScalarList`).
-        keep_input : bool (default: True)
-            A boolean indicating whether config passed should be deleted from the input
-            dictionary immediately after adding them to the record. When set
-            to True, the data is duplicated in memory. If memory is a concern, set
-            it to False.
-        """
-        if any(not isinstance(k, str) for k in configs_dict.keys()):
-            raise TypeError(f"Not all keys are of valid type. Expected {str}")
-
-        def is_valid(value: ConfigsScalar) -> None:
-            """Check if value is of expected type."""
-            if not isinstance(value, get_args(ConfigsScalar)):
-                raise TypeError(
-                    "Not all values are of valid type."
-                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
-                )
-
-        # Check types of values
-        # Split between those values that are list and those that aren't
-        # then process in the same way
-        for value in configs_dict.values():
-            if isinstance(value, list):
-                # If your lists are large (e.g. 1M+ elements) this will be slow
-                # 1s to check 10M element list on a M2 Pro
-                # In such settings, you'd be better of treating such config as
-                # an array and pass it to a ParametersRecord.
-                # Empty lists are valid
-                if len(value) > 0:
-                    is_valid(value[0])
-                    # all elements in the list must be of the same valid type
-                    # this is needed for protobuf
-                    value_type = type(value[0])
-                    if not all(isinstance(v, value_type) for v in value):
-                        raise TypeError(
-                            "All values in a list must be of the same valid type. "
-                            f"One of {ConfigsScalar}."
-                        )
-            else:
-                is_valid(value)
-
-        # Add configs to record
-        if keep_input:
-            # Copy
-            self.data = configs_dict.copy()
-        else:
-            # Add entries to dataclass without duplicating memory
-            for key in list(configs_dict.keys()):
-                self.data[key] = configs_dict[key]
-                del configs_dict[key]
-
-    def __getitem__(self, key: str) -> ConfigsRecordValues:
-        """Retrieve an element stored in record."""
-        return self.data[key]
diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py
index 8d1d865f084b..2946a594e68c 100644
--- a/src/py/flwr/common/constant.py
+++ b/src/py/flwr/common/constant.py
@@ -15,6 +15,8 @@
 """Flower constants."""
 
+from __future__ import annotations
+
 MISSING_EXTRA_REST = """
 Extra dependencies required for using the REST-based Fleet API are missing.
 
@@ -26,13 +28,25 @@
 TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
 TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
 TRANSPORT_TYPE_REST = "rest"
+TRANSPORT_TYPE_VCE = "vce"
 TRANSPORT_TYPES = [
     TRANSPORT_TYPE_GRPC_BIDI,
     TRANSPORT_TYPE_GRPC_RERE,
     TRANSPORT_TYPE_REST,
+    TRANSPORT_TYPE_VCE,
 ]
 
-TASK_TYPE_GET_PROPERTIES = "get_properties"
-TASK_TYPE_GET_PARAMETERS = "get_parameters"
-TASK_TYPE_FIT = "fit"
-TASK_TYPE_EVALUATE = "evaluate"
+MESSAGE_TYPE_GET_PROPERTIES = "get_properties"
+MESSAGE_TYPE_GET_PARAMETERS = "get_parameters"
+MESSAGE_TYPE_FIT = "fit"
+MESSAGE_TYPE_EVALUATE = "evaluate"
+
+
+class SType:
+    """Serialisation type."""
+
+    NUMPY = "numpy.ndarray"
+
+    def __new__(cls) -> SType:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py
index 30c1131a206f..b6349307d150 100644
--- a/src/py/flwr/common/context.py
+++ b/src/py/flwr/common/context.py
@@ -17,7 +17,7 @@
 
 from dataclasses import dataclass
 
-from .recordset import RecordSet
+from .record import RecordSet
 
 
 @dataclass
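
As an aside, the new `SType` class above is a plain namespace of serialisation tags; its `__new__` guard makes accidental instantiation fail loudly. A tiny illustration (not part of the patch):

from flwr.common.constant import SType

assert SType.NUMPY == "numpy.ndarray"
try:
    SType()
except TypeError as err:
    print(err)  # "SType cannot be instantiated."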
diff --git a/src/py/flwr/common/differential_privacy.py b/src/py/flwr/common/differential_privacy.py
new file mode 100644
index 000000000000..3b81fba44733
--- /dev/null
+++ b/src/py/flwr/common/differential_privacy.py
@@ -0,0 +1,164 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for differential privacy."""
+
+
+from logging import WARNING
+from typing import Optional, Tuple
+
+import numpy as np
+
+from flwr.common import (
+    NDArrays,
+    Parameters,
+    ndarrays_to_parameters,
+    parameters_to_ndarrays,
+)
+from flwr.common.logger import log
+
+
+def get_norm(input_arrays: NDArrays) -> float:
+    """Compute the L2 norm of the flattened input."""
+    array_norms = [np.linalg.norm(array.flat) for array in input_arrays]
+    # pylint: disable=consider-using-generator
+    return float(np.sqrt(sum([norm**2 for norm in array_norms])))
+
+
+def add_gaussian_noise_inplace(input_arrays: NDArrays, std_dev: float) -> None:
+    """Add Gaussian noise to each element of the input arrays."""
+    for array in input_arrays:
+        array += np.random.normal(0, std_dev, array.shape)
+
+
+def clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> None:
+    """Clip model update based on the clipping norm in-place.
+
+    FlatClip method of the paper: https://arxiv.org/abs/1710.06963
+    """
+    input_norm = get_norm(input_arrays)
+    scaling_factor = min(1, clipping_norm / input_norm)
+    for array in input_arrays:
+        array *= scaling_factor
+
+
+def compute_stdv(
+    noise_multiplier: float, clipping_norm: float, num_sampled_clients: int
+) -> float:
+    """Compute standard deviation for noise addition.
+
+    Paper: https://arxiv.org/abs/1710.06963
+    """
+    return float((noise_multiplier * clipping_norm) / num_sampled_clients)
+
+
+def compute_clip_model_update(
+    param1: NDArrays, param2: NDArrays, clipping_norm: float
+) -> None:
+    """Compute the model update (param1 - param2) and clip it.
+
+    Then store param2 + clipped update back into param1, in-place."""
+    model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)]
+    clip_inputs_inplace(model_update, clipping_norm)
+
+    for i, _ in enumerate(param2):
+        param1[i] = param2[i] + model_update[i]
+
+
+def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool:
+    """Clip model update based on the clipping norm in-place.
+
+    It returns True if scaling_factor < 1, which is used for the norm_bit.
+    FlatClip method of the paper: https://arxiv.org/abs/1710.06963
+    """
+    input_norm = get_norm(input_arrays)
+    scaling_factor = min(1, clipping_norm / input_norm)
+    for array in input_arrays:
+        array *= scaling_factor
+    return scaling_factor < 1
+
+
+def compute_adaptive_clip_model_update(
+    param1: NDArrays, param2: NDArrays, clipping_norm: float
+) -> bool:
+    """Compute the model update, clip it, then store param2 + clipped update in param1.
+
+    model update = param1 - param2
+    Return the norm_bit.
+    """
+    model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)]
+    norm_bit = adaptive_clip_inputs_inplace(model_update, clipping_norm)
+
+    for i, _ in enumerate(param2):
+        param1[i] = param2[i] + model_update[i]
+
+    return norm_bit
+
+
+def add_gaussian_noise_to_params(
+    model_params: Parameters,
+    noise_multiplier: float,
+    clipping_norm: float,
+    num_sampled_clients: int,
+) -> Parameters:
+    """Add Gaussian noise to model parameters."""
+    model_params_ndarrays = parameters_to_ndarrays(model_params)
+    add_gaussian_noise_inplace(
+        model_params_ndarrays,
+        compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients),
+    )
+    return ndarrays_to_parameters(model_params_ndarrays)
+
+
+def compute_adaptive_noise_params(
+    noise_multiplier: float,
+    num_sampled_clients: float,
+    clipped_count_stddev: Optional[float],
+) -> Tuple[float, float]:
+    """Compute noising parameters for the adaptive clipping.
+
+    Paper: https://arxiv.org/abs/1905.03871
+    """
+    if noise_multiplier > 0:
+        if clipped_count_stddev is None:
+            clipped_count_stddev = num_sampled_clients / 20
+        if noise_multiplier >= 2 * clipped_count_stddev:
+            raise ValueError(
+                f"If not specified, `clipped_count_stddev` is set to "
+                f"`num_sampled_clients`/20 by default. This value "
+                f"({num_sampled_clients / 20}) is too low to achieve the "
+                f"desired effective `noise_multiplier` ({noise_multiplier}). "
+                f"Consider increasing `clipped_count_stddev` or decreasing "
+                f"`noise_multiplier`."
+            )
+        noise_multiplier_value = (
+            noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)
+        ) ** -0.5
+
+        adding_noise = noise_multiplier_value / noise_multiplier
+        if adding_noise >= 2:
+            log(
+                WARNING,
+                "A significant amount of noise (%s) has to be "
+                "added. Consider increasing `clipped_count_stddev` or "
+                "`num_sampled_clients`.",
+                adding_noise,
+            )
+
+    else:
+        if clipped_count_stddev is None:
+            clipped_count_stddev = 0.0
+        noise_multiplier_value = 0.0
+
+    return clipped_count_stddev, noise_multiplier_value
diff --git a/src/py/flwr/common/differential_privacy_constants.py b/src/py/flwr/common/differential_privacy_constants.py
new file mode 100644
index 000000000000..415759dfdf0b
--- /dev/null
+++ b/src/py/flwr/common/differential_privacy_constants.py
@@ -0,0 +1,25 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for differential privacy."""
+
+
+KEY_CLIPPING_NORM = "clipping_norm"
+KEY_NORM_BIT = "norm_bit"
+CLIENTS_DISCREPANCY_WARNING = (
+    "The number of clients returning parameters (%s)"
+    " differs from the number of sampled clients (%s)."
+    " This could impact the differential privacy guarantees,"
+    " potentially leading to privacy leakage or inadequate noise calibration."
+)
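
A brief usage sketch (illustrative values, not part of the patch) tying the helpers above together: clip an update to a target L2 norm, then add noise scaled per https://arxiv.org/abs/1710.06963:

import numpy as np

from flwr.common.differential_privacy import (
    add_gaussian_noise_inplace,
    clip_inputs_inplace,
    compute_stdv,
    get_norm,
)

update = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]
print(get_norm(update))  # 13.0 = sqrt(3^2 + 4^2 + 12^2)
clip_inputs_inplace(update, clipping_norm=1.0)  # rescale so the L2 norm is <= 1.0
std_dev = compute_stdv(noise_multiplier=1.0, clipping_norm=1.0, num_sampled_clients=10)
add_gaussian_noise_inplace(update, std_dev)  # std_dev == 0.1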
+)
diff --git a/src/py/flwr/common/differential_privacy_test.py b/src/py/flwr/common/differential_privacy_test.py
new file mode 100644
index 000000000000..ea2e193df2f3
--- /dev/null
+++ b/src/py/flwr/common/differential_privacy_test.py
@@ -0,0 +1,209 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Differential Privacy (DP) utility functions tests."""
+
+
+import numpy as np
+
+from .differential_privacy import (
+    add_gaussian_noise_inplace,
+    clip_inputs_inplace,
+    compute_adaptive_noise_params,
+    compute_clip_model_update,
+    compute_stdv,
+    get_norm,
+)
+
+
+def test_add_gaussian_noise_inplace() -> None:
+    """Test add_gaussian_noise_inplace function."""
+    # Prepare
+    update = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
+    std_dev = 0.1
+
+    # Execute
+    add_gaussian_noise_inplace(update, std_dev)
+
+    # Assert
+    # Check that the shape of the result is the same as the input
+    for layer in update:
+        assert layer.shape == (2, 2)
+
+    # Check that the values have been changed and are not equal to the original update
+    for layer in update:
+        assert not np.array_equal(
+            layer, [[1.0, 2.0], [3.0, 4.0]]
+        ) and not np.array_equal(layer, [[5.0, 6.0], [7.0, 8.0]])
+
+    # Check that the noise has been added
+    for layer in update:
+        noise_added = (
+            layer - np.array([[1.0, 2.0], [3.0, 4.0]])
+            if np.array_equal(layer, [[1.0, 2.0], [3.0, 4.0]])
+            else layer - np.array([[5.0, 6.0], [7.0, 8.0]])
+        )
+        assert np.any(np.abs(noise_added) > 0)
+
+
+def test_get_norm() -> None:
+    """Test get_norm function."""
+    # Prepare
+    update = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
+
+    # Execute
+    result = get_norm(update)
+
+    expected = float(
+        np.linalg.norm(np.concatenate([sub_update.flatten() for sub_update in update]))
+    )
+
+    # Assert
+    assert expected == result
+
+
+def test_clip_inputs_inplace() -> None:
+    """Test clip_inputs_inplace function."""
+    # Prepare
+    updates = [
+        np.array([[1.5, -0.5], [2.0, -1.0]]),
+        np.array([0.5, -0.5]),
+        np.array([[-0.5, 1.5], [-1.0, 2.0]]),
+        np.array([-0.5, 0.5]),
+    ]
+    clipping_norm = 1.5
+
+    original_updates = [np.copy(update) for update in updates]
+
+    # Execute
+    clip_inputs_inplace(updates, clipping_norm)
+
+    # Assert
+    for updated, original_update in zip(updates, original_updates):
+        clip_norm = np.linalg.norm(original_update)
+        assert np.all(updated <= clip_norm) and np.all(updated >= -clip_norm)
+
+
+def test_compute_stdv() -> None:
+    """Test compute_stdv function."""
+    # Prepare
+    noise_multiplier = 1.0
+    clipping_norm = 0.5
+    num_sampled_clients = 10
+
+    # Execute
+    stdv = compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients)
+
+    # Assert
+    expected_stdv = float((noise_multiplier * clipping_norm) / num_sampled_clients)
+    assert stdv == expected_stdv
+
+
+def test_compute_clip_model_update() -> None:
+    """Test compute_clip_model_update function."""
+    # Prepare
+    param1 = [
+        np.array([0.5, 1.5, 2.5]),
+        np.array([3.5, 4.5, 5.5]),
+        np.array([6.5, 7.5, 8.5]),
+    ]
+    param2 = [
+        np.array([1.0, 2.0, 3.0]),
+        np.array([4.0, 5.0, 6.0]),
+        np.array([7.0, 8.0, 9.0]),
+    ]
+    clipping_norm = 4
+
+    expected_result = [
+        np.array([0.5, 1.5, 2.5]),
+        np.array([3.5, 4.5, 5.5]),
+        np.array([6.5, 7.5, 8.5]),
+    ]
+
+    # Execute
+    compute_clip_model_update(param1, param2, clipping_norm)
+
+    # Assert
+    for i, param in enumerate(param1):
+        np.testing.assert_array_almost_equal(param, expected_result[i])
+
+
+def test_compute_adaptive_noise_params() -> None:
+    """Test compute_adaptive_noise_params function."""
+    # Test valid input with positive noise_multiplier
+    noise_multiplier = 1.0
+    num_sampled_clients = 100.0
+    clipped_count_stddev = None
+    result = compute_adaptive_noise_params(
+        noise_multiplier, num_sampled_clients, clipped_count_stddev
+    )
+
+    # Assert
+    assert isinstance(result, tuple)
+    assert len(result) == 2
+    assert result[0] > 0.0
+    assert result[1] > 0.0
+
+    # Test valid input with zero noise_multiplier
+    noise_multiplier = 0.0
+    num_sampled_clients = 50.0
+    clipped_count_stddev = None
+    result = compute_adaptive_noise_params(
+        noise_multiplier, num_sampled_clients, clipped_count_stddev
+    )
+
+    # Assert
+    assert isinstance(result, tuple)
+    assert len(result) == 2
+    assert result[0] == 0.0
+    assert result[1] == 0.0
+
+    # Test valid input with specified clipped_count_stddev
+    noise_multiplier = 3.0
+    num_sampled_clients = 200.0
+    clipped_count_stddev = 5.0
+    result = compute_adaptive_noise_params(
+        noise_multiplier, num_sampled_clients, clipped_count_stddev
+    )
+
+    # Assert
+    assert isinstance(result, tuple)
+    assert len(result) == 2
+    assert result[0] == clipped_count_stddev
+    assert result[1] > 0.0
+
+    # Test invalid input with noise_multiplier >= 2 * clipped_count_stddev
+    noise_multiplier = 10.0
+    num_sampled_clients = 100.0
+    clipped_count_stddev = None
+    try:
+        compute_adaptive_noise_params(
+            noise_multiplier, num_sampled_clients, clipped_count_stddev
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Expected ValueError not raised.")
+
+    # Test intermediate calculation
+    noise_multiplier = 3.0
+    num_sampled_clients = 200.0
+    clipped_count_stddev = 5.0
+    result = compute_adaptive_noise_params(
+        noise_multiplier, num_sampled_clients, clipped_count_stddev
+    )
+    temp_value = (noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)) ** -0.5
+
+    # Assert
+    assert np.isclose(result[1], temp_value, rtol=1e-6)
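
Worked numbers (mirroring the last test above, not part of the patch): with `noise_multiplier=3.0` and `clipped_count_stddev=5.0`, the effective multiplier is `(3.0**-2 - (2 * 5.0)**-2) ** -0.5 ≈ 3.145`:

from flwr.common.differential_privacy import compute_adaptive_noise_params

stddev, multiplier = compute_adaptive_noise_params(
    noise_multiplier=3.0, num_sampled_clients=200.0, clipped_count_stddev=5.0
)
assert stddev == 5.0
print(multiplier)  # ~3.145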
diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py
new file mode 100644
index 000000000000..30750c28a450
--- /dev/null
+++ b/src/py/flwr/common/exit_handlers.py
@@ -0,0 +1,87 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common function to register exit handlers for server and client."""
+
+
+import sys
+from signal import SIGINT, SIGTERM, signal
+from threading import Thread
+from types import FrameType
+from typing import List, Optional
+
+from grpc import Server
+
+from flwr.common.telemetry import EventType, event
+
+
+def register_exit_handlers(
+    event_type: EventType,
+    grpc_servers: Optional[List[Server]] = None,
+    bckg_threads: Optional[List[Thread]] = None,
+) -> None:
+    """Register exit handlers for `SIGINT` and `SIGTERM` signals.
+
+    Parameters
+    ----------
+    event_type : EventType
+        The telemetry event that should be logged before exit.
+    grpc_servers: Optional[List[Server]] (default: None)
+        An optional list of gRPC servers that need to be gracefully
+        terminated before exiting.
+    bckg_threads: Optional[List[Thread]] (default: None)
+        An optional list of threads that need to be gracefully
+        terminated before exiting.
+    """
+    default_handlers = {
+        SIGINT: None,
+        SIGTERM: None,
+    }
+
+    def graceful_exit_handler(  # type: ignore
+        signalnum,
+        frame: FrameType,  # pylint: disable=unused-argument
+    ) -> None:
+        """Exit handler to be registered with `signal.signal`.
+
+        When called, it resets the signal handler to the original handler
+        stored in default_handlers.
+        """
+        # Reset to default handler
+        signal(signalnum, default_handlers[signalnum])
+
+        event_res = event(event_type=event_type)
+
+        if grpc_servers is not None:
+            for grpc_server in grpc_servers:
+                grpc_server.stop(grace=1)
+
+        if bckg_threads is not None:
+            for bckg_thread in bckg_threads:
+                bckg_thread.join()
+
+        # Ensure event has happened
+        event_res.result()
+
+        # Set up things for graceful exit
+        sys.exit(0)
+
+    default_handlers[SIGINT] = signal(  # type: ignore
+        SIGINT,
+        graceful_exit_handler,  # type: ignore
+    )
+    default_handlers[SIGTERM] = signal(  # type: ignore
+        SIGTERM,
+        graceful_exit_handler,  # type: ignore
+    )
diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py
index f693d8e27bc3..1e1132e42e27 100644
--- a/src/py/flwr/common/message.py
+++ b/src/py/flwr/common/message.py
@@ -14,37 +14,144 @@
 # ==============================================================================
 """Message."""
 
+from __future__ import annotations
 
 from dataclasses import dataclass
 
-from .recordset import RecordSet
+from .record import RecordSet
 
 
 @dataclass
-class Metadata:
-    """A dataclass holding metadata associated with the current task.
+class Metadata:  # pylint: disable=too-many-instance-attributes
+    """A dataclass holding metadata associated with the current message.
 
     Parameters
     ----------
     run_id : int
         An identifier for the current run.
-    task_id : str
-        An identifier for the current task.
+    message_id : str
+        An identifier for the current message.
+    src_node_id : int
+        An identifier for the node sending this message.
+    dst_node_id : int
+        An identifier for the node receiving this message.
+    reply_to_message : str
+        An identifier for the message this message replies to.
     group_id : str
-        An identifier for grouping tasks. In some settings
+        An identifier for grouping messages. In some settings,
         this is used as the FL round.
     ttl : str
-        Time-to-live for this task.
-    task_type : str
+        Time-to-live for this message.
+    message_type : str
         A string that encodes the action to be executed on
         the receiving end.
+    partition_id : Optional[int]
+        An identifier that can be used when loading a particular
+        data partition for a ClientApp. Making use of this identifier
+        is more relevant when conducting simulations.
     """
 
-    run_id: int
-    task_id: str
-    group_id: str
-    ttl: str
-    task_type: str
+    _run_id: int
+    _message_id: str
+    _src_node_id: int
+    _dst_node_id: int
+    _reply_to_message: str
+    _group_id: str
+    _ttl: str
+    _message_type: str
+    _partition_id: int | None
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        run_id: int,
+        message_id: str,
+        src_node_id: int,
+        dst_node_id: int,
+        reply_to_message: str,
+        group_id: str,
+        ttl: str,
+        message_type: str,
+        partition_id: int | None = None,
+    ) -> None:
+        self._run_id = run_id
+        self._message_id = message_id
+        self._src_node_id = src_node_id
+        self._dst_node_id = dst_node_id
+        self._reply_to_message = reply_to_message
+        self._group_id = group_id
+        self._ttl = ttl
+        self._message_type = message_type
+        self._partition_id = partition_id
+
+    @property
+    def run_id(self) -> int:
+        """An identifier for the current run."""
+        return self._run_id
+
+    @property
+    def message_id(self) -> str:
+        """An identifier for the current message."""
+        return self._message_id
+
+    @property
+    def src_node_id(self) -> int:
+        """An identifier for the node sending this message."""
+        return self._src_node_id
+
+    @property
+    def reply_to_message(self) -> str:
+        """An identifier for the message this message replies to."""
+        return self._reply_to_message
+
+    @property
+    def dst_node_id(self) -> int:
+        """An identifier for the node receiving this message."""
+        return self._dst_node_id
+
+    @dst_node_id.setter
+    def dst_node_id(self, value: int) -> None:
+        """Set dst_node_id."""
+        self._dst_node_id = value
+
+    @property
+    def group_id(self) -> str:
+        """An identifier for grouping messages."""
+        return self._group_id
+
+    @group_id.setter
+    def group_id(self, value: str) -> None:
+        """Set group_id."""
+        self._group_id = value
+
+    @property
+    def ttl(self) -> str:
+        """Time-to-live for this message."""
+        return self._ttl
+
+    @ttl.setter
+    def ttl(self, value: str) -> None:
+        """Set ttl."""
+        self._ttl = value
+
+    @property
+    def message_type(self) -> str:
+        """A string that encodes the action to be executed on the receiving end."""
+        return self._message_type
+
+    @message_type.setter
+    def message_type(self, value: str) -> None:
+        """Set message_type."""
+        self._message_type = value
+
+    @property
+    def partition_id(self) -> int | None:
+        """An identifier telling which data partition a ClientApp should use."""
+        return self._partition_id
+
+    @partition_id.setter
+    def partition_id(self, value: int) -> None:
+        """Set partition_id."""
+        self._partition_id = value
 
 
 @dataclass
@@ -54,11 +161,64 @@ class Message:
     Parameters
     ----------
     metadata : Metadata
-        A dataclass including information about the task to be executed.
-    message : RecordSet
+        A dataclass including information about the message to be executed.
+    content : RecordSet
         Holds records either sent by another entity (e.g. sent by the server-side
         logic to a client, or vice-versa) or that will be sent to it.
     """
 
-    metadata: Metadata
-    message: RecordSet
+    _metadata: Metadata
+    _content: RecordSet
+
+    def __init__(self, metadata: Metadata, content: RecordSet) -> None:
+        self._metadata = metadata
+        self._content = content
+
+    @property
+    def metadata(self) -> Metadata:
+        """A dataclass including information about the message to be executed."""
+        return self._metadata
+
+    @property
+    def content(self) -> RecordSet:
+        """The content of this message."""
+        return self._content
+
+    @content.setter
+    def content(self, value: RecordSet) -> None:
+        """Set content."""
+        self._content = value
+
+    def create_reply(self, content: RecordSet, ttl: str) -> Message:
+        """Create a reply to this message with specified content and TTL.
+
+        The method generates a new `Message` as a reply to this message.
+        It inherits 'run_id', 'group_id', and 'message_type' from this message,
+        swaps 'src_node_id' and 'dst_node_id', and sets 'reply_to_message' to
+        the ID of this message.
+
+        Parameters
+        ----------
+        content : RecordSet
+            The content for the reply message.
+        ttl : str
+            Time-to-live for this message.
+
+        Returns
+        -------
+        Message
+            A new `Message` instance representing the reply.
+        """
+        return Message(
+            metadata=Metadata(
+                run_id=self.metadata.run_id,
+                message_id="",
+                src_node_id=self.metadata.dst_node_id,
+                dst_node_id=self.metadata.src_node_id,
+                reply_to_message=self.metadata.message_id,
+                group_id=self.metadata.group_id,
+                ttl=ttl,
+                message_type=self.metadata.message_type,
+                partition_id=self.metadata.partition_id,
+            ),
+            content=content,
+        )
- """ - self.data = {} - if metrics_dict: - self.set_metrics(metrics_dict, keep_input=keep_input) - - def set_metrics( - self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True - ) -> None: - """Add metrics to the record. - - Parameters - ---------- - metrics_dict : Dict[str, MetricsRecordValues] - A dictionary that stores basic types (i.e. `int`, `float` as defined - in `MetricsScalar`) and list of such types (see `MetricsScalarList`). - keep_input : bool (default: True) - A boolean indicating whether metrics should be deleted from the input - dictionary immediately after adding them to the record. When set - to True, the data is duplicated in memory. If memory is a concern, set - it to False. - """ - if any(not isinstance(k, str) for k in metrics_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}.") - - def is_valid(value: MetricsScalar) -> None: - """Check if value is of expected type.""" - if not isinstance(value, get_args(MetricsScalar)) or isinstance( - value, bool - ): - raise TypeError( - "Not all values are of valid type." - f" Expected {MetricsRecordValues} but you passed {type(value)}." - ) - - # Check types of values - # Split between those values that are list and those that aren't - # then process in the same way - for value in metrics_dict.values(): - if isinstance(value, list): - # If your lists are large (e.g. 1M+ elements) this will be slow - # 1s to check 10M element list on a M2 Pro - # In such settings, you'd be better of treating such metric as - # an array and pass it to a ParametersRecord. - # Empty lists are valid - if len(value) > 0: - is_valid(value[0]) - # all elements in the list must be of the same valid type - # this is needed for protobuf - value_type = type(value[0]) - if not all(isinstance(v, value_type) for v in value): - raise TypeError( - "All values in a list must be of the same valid type. " - f"One of {MetricsScalar}." - ) - else: - is_valid(value) - - # Add metrics to record - if keep_input: - # Copy - self.data = metrics_dict.copy() - else: - # Add entries to dataclass without duplicating memory - for key in list(metrics_dict.keys()): - self.data[key] = metrics_dict[key] - del metrics_dict[key] - - def __getitem__(self, key: str) -> MetricsRecordValues: - """Retrieve an element stored in record.""" - return self.data[key] diff --git a/src/py/flwr/common/record/__init__.py b/src/py/flwr/common/record/__init__.py new file mode 100644 index 000000000000..60bc54b8552a --- /dev/null +++ b/src/py/flwr/common/record/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Record APIs.""" + +from .configsrecord import ConfigsRecord +from .conversion_utils import array_from_numpy +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet + +__all__ = [ + "Array", + "array_from_numpy", + "ConfigsRecord", + "MetricsRecord", + "ParametersRecord", + "RecordSet", +] diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py new file mode 100644 index 000000000000..704657601f50 --- /dev/null +++ b/src/py/flwr/common/record/configsrecord.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ConfigsRecord.""" + + +from typing import Dict, Optional, get_args + +from flwr.common.typing import ConfigsRecordValues, ConfigsScalar + +from .typeddict import TypedDict + + +def _check_key(key: str) -> None: + """Check if key is of expected type.""" + if not isinstance(key, str): + raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.") + + +def _check_value(value: ConfigsRecordValues) -> None: + def is_valid(__v: ConfigsScalar) -> None: + """Check if value is of expected type.""" + if not isinstance(__v, get_args(ConfigsScalar)): + raise TypeError( + "Not all values are of valid type." + f" Expected `{ConfigsRecordValues}` but `{type(__v)}` was passed." + ) + + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such config as + # an array and pass it to a ParametersRecord. + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {ConfigsScalar}." + ) + else: + is_valid(value) + + +class ConfigsRecord(TypedDict[str, ConfigsRecordValues]): + """Configs record.""" + + def __init__( + self, + configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None, + keep_input: bool = True, + ) -> None: + """Construct a ConfigsRecord object. + + Parameters + ---------- + configs_dict : Optional[Dict[str, ConfigsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as + defined in `ConfigsScalar`) and lists of such types (see + `ConfigsScalarList`). + keep_input : bool (default: True) + A boolean indicating whether config passed should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. 
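
A quick sketch (not part of the patch, and assuming the `TypedDict`-based validation implemented in the files that follow): the record types re-exported above behave like type-checked dictionaries, so invalid values raise a `TypeError` on assignment:

from flwr.common.record import ConfigsRecord, MetricsRecord

configs = ConfigsRecord({"lr": 0.01, "tags": ["a", "b"]})
configs["epochs"] = 5  # plain assignment replaces the old set_configs()
metrics = MetricsRecord({"accuracy": 0.91})
try:
    metrics["flag"] = True  # bools are rejected by MetricsRecord's value check
except TypeError as err:
    print(err)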
+ """ + super().__init__(_check_key, _check_value) + if configs_dict: + for k in list(configs_dict.keys()): + self[k] = configs_dict[k] + if not keep_input: + del configs_dict[k] diff --git a/src/py/flwr/common/record/conversion_utils.py b/src/py/flwr/common/record/conversion_utils.py new file mode 100644 index 000000000000..7cc0b04283e9 --- /dev/null +++ b/src/py/flwr/common/record/conversion_utils.py @@ -0,0 +1,40 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion utility functions for Records.""" + + +from io import BytesIO + +import numpy as np + +from ..constant import SType +from ..typing import NDArray +from .parametersrecord import Array + + +def array_from_numpy(ndarray: NDArray) -> Array: + """Create Array from NumPy ndarray.""" + buffer = BytesIO() + # WARNING: NEVER set allow_pickle to true. + # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html + np.save(buffer, ndarray, allow_pickle=False) + data = buffer.getvalue() + return Array( + dtype=str(ndarray.dtype), + shape=list(ndarray.shape), + stype=SType.NUMPY, + data=data, + ) diff --git a/src/py/flwr/common/record/conversion_utils_test.py b/src/py/flwr/common/record/conversion_utils_test.py new file mode 100644 index 000000000000..84be37fda4a3 --- /dev/null +++ b/src/py/flwr/common/record/conversion_utils_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Unit tests for conversion utility functions."""
+
+
+import unittest
+from io import BytesIO
+
+import numpy as np
+
+from ..constant import SType
+from .conversion_utils import array_from_numpy
+
+
+class TestArrayFromNumpy(unittest.TestCase):
+    """Unit tests for array_from_numpy."""
+
+    def test_array_from_numpy(self) -> None:
+        """Test the array_from_numpy function."""
+        # Prepare
+        original_array = np.array([1, 2, 3], dtype=np.float32)
+
+        # Execute
+        array_instance = array_from_numpy(original_array)
+        buffer = BytesIO(array_instance.data)
+        deserialized_array = np.load(buffer, allow_pickle=False)
+
+        # Assert
+        self.assertEqual(array_instance.dtype, str(original_array.dtype))
+        self.assertEqual(array_instance.shape, list(original_array.shape))
+        self.assertEqual(array_instance.stype, SType.NUMPY)
+        np.testing.assert_array_equal(deserialized_array, original_array)
diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py
new file mode 100644
index 000000000000..81b02303421b
--- /dev/null
+++ b/src/py/flwr/common/record/metricsrecord.py
@@ -0,0 +1,86 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MetricsRecord."""
+
+
+from typing import Dict, Optional, get_args
+
+from flwr.common.typing import MetricsRecordValues, MetricsScalar
+
+from .typeddict import TypedDict
+
+
+def _check_key(key: str) -> None:
+    """Check if key is of expected type."""
+    if not isinstance(key, str):
+        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
+
+
+def _check_value(value: MetricsRecordValues) -> None:
+    def is_valid(__v: MetricsScalar) -> None:
+        """Check if value is of expected type."""
+        if not isinstance(__v, get_args(MetricsScalar)) or isinstance(__v, bool):
+            raise TypeError(
+                "Not all values are of valid type."
+                f" Expected `{MetricsRecordValues}` but `{type(__v)}` was passed."
+            )
+
+    if isinstance(value, list):
+        # If your lists are large (e.g. 1M+ elements) this check will be slow
+        # (~1s for a 10M-element list on an M2 Pro).
+        # In such settings, you'd be better off treating such a metric as
+        # an array and passing it to a ParametersRecord.
+        # Empty lists are valid
+        if len(value) > 0:
+            is_valid(value[0])
+            # all elements in the list must be of the same valid type
+            # this is needed for protobuf
+            value_type = type(value[0])
+            if not all(isinstance(v, value_type) for v in value):
+                raise TypeError(
+                    "All values in a list must be of the same valid type. "
+                    f"One of {MetricsScalar}."
+                )
+    else:
+        is_valid(value)
+
+
+class MetricsRecord(TypedDict[str, MetricsRecordValues]):
+    """Metrics record."""
+
+    def __init__(
+        self,
+        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
+        keep_input: bool = True,
+    ):
+        """Construct a MetricsRecord object.
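+
+        A minimal usage sketch (keys and values below are illustrative)::
+
+            record = MetricsRecord({"accuracy": 0.945, "losses": [0.12, 0.08]})
+            record["num-examples"] = 1024  # int/float only; bool is rejected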
+
+        Parameters
+        ----------
+        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
+            A dictionary that stores basic types (i.e. `int`, `float` as defined
+            in `MetricsScalar`) and lists of such types (see `MetricsScalarList`).
+        keep_input : bool (default: True)
+            A boolean indicating whether the metrics passed should be kept in the
+            input dictionary after being added to the record. When set to True,
+            the data is duplicated in memory. If memory is a concern, set it to
+            False.
+        """
+        super().__init__(_check_key, _check_value)
+        if metrics_dict:
+            for k in list(metrics_dict.keys()):
+                self[k] = metrics_dict[k]
+                if not keep_input:
+                    del metrics_dict[k]
diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py
similarity index 64%
rename from src/py/flwr/common/parametersrecord.py
rename to src/py/flwr/common/record/parametersrecord.py
index ef02a0789ddf..17bf3f608db7
--- a/src/py/flwr/common/parametersrecord.py
+++ b/src/py/flwr/common/record/parametersrecord.py
@@ -14,9 +14,15 @@
 # ==============================================================================
 """ParametersRecord and Array."""
 
+from dataclasses import dataclass
+from io import BytesIO
+from typing import List, Optional, OrderedDict, cast
 
-from dataclasses import dataclass, field
-from typing import List, Optional, OrderedDict
+import numpy as np
+
+from ..constant import SType
+from ..typing import NDArray
+from .typeddict import TypedDict
 
 
 @dataclass
@@ -49,9 +55,35 @@ class Array:
     stype: str
     data: bytes
 
+    def numpy(self) -> NDArray:
+        """Return the array as a NumPy array."""
+        if self.stype != SType.NUMPY:
+            raise TypeError(
+                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
+            )
+        bytes_io = BytesIO(self.data)
+        # WARNING: NEVER set allow_pickle to true.
+        # Reason: loading pickled data can execute arbitrary code
+        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
+        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
+        return cast(NDArray, ndarray_deserialized)
+
+
+def _check_key(key: str) -> None:
+    """Check if key is of expected type."""
+    if not isinstance(key, str):
+        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
+
+
+def _check_value(value: Array) -> None:
+    if not isinstance(value, Array):
+        raise TypeError(
+            f"Value must be of type `{Array}` but `{type(value)}` was passed."
+        )
+
 
 @dataclass
-class ParametersRecord:
+class ParametersRecord(TypedDict[str, Array]):
     """Parameters record.
 
     A dataclass storing named Arrays in order. This means that it holds entries as an
@@ -59,8 +91,6 @@ class ParametersRecord(TypedDict[str, Array]):
     PyTorch's state_dict, but holding serialised tensors instead.
     """
 
-    data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
-
     def __init__(
         self,
         array_dict: Optional[OrderedDict[str, Array]] = None,
@@ -81,37 +111,9 @@ def __init__(
         parameters after adding it to the record, set this flag to True. When set
         to True, the data is duplicated in memory.
         """
-        self.data = OrderedDict()
+        super().__init__(_check_key, _check_value)
         if array_dict:
-            self.set_parameters(array_dict, keep_input=keep_input)
-
-    def set_parameters(
-        self, array_dict: OrderedDict[str, Array], keep_input: bool = False
-    ) -> None:
-        """Add parameters to record.
-
-        Parameters
-        ----------
-        array_dict : OrderedDict[str, Array]
-            A dictionary that stores serialized array-like or tensor-like objects.
- keep_input : bool (default: False) - A boolean indicating whether parameters should be deleted from the input - dictionary immediately after adding them to the record. - """ - if any(not isinstance(k, str) for k in array_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}") - if any(not isinstance(v, Array) for v in array_dict.values()): - raise TypeError(f"Not all values are of valid type. Expected {Array}") - - if keep_input: - # Copy - self.data = OrderedDict(array_dict) - else: - # Add entries to dataclass without duplicating memory - for key in list(array_dict.keys()): - self.data[key] = array_dict[key] - del array_dict[key] - - def __getitem__(self, key: str) -> Array: - """Retrieve an element stored in record.""" - return self.data[key] + for k in list(array_dict.keys()): + self[k] = array_dict[k] + if not keep_input: + del array_dict[k] diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py new file mode 100644 index 000000000000..9633af7bda6d --- /dev/null +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for ParametersRecord and Array.""" + + +import unittest +from io import BytesIO + +import numpy as np + +from ..constant import SType +from .parametersrecord import Array + + +class TestArray(unittest.TestCase): + """Unit tests for Array.""" + + def test_numpy_conversion_valid(self) -> None: + """Test the numpy method with valid Array instance.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + buffer = BytesIO() + np.save(buffer, original_array, allow_pickle=False) + buffer.seek(0) + + # Execute + array_instance = Array( + dtype=str(original_array.dtype), + shape=list(original_array.shape), + stype=SType.NUMPY, + data=buffer.read(), + ) + converted_array = array_instance.numpy() + + # Assert + np.testing.assert_array_equal(converted_array, original_array) + + def test_numpy_conversion_invalid(self) -> None: + """Test the numpy method with invalid Array instance.""" + # Prepare + array_instance = Array( + dtype="float32", + shape=[3], + stype="invalid_stype", # Non-numpy stype + data=b"", + ) + + # Execute and assert + with self.assertRaises(TypeError): + array_instance.numpy() diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py new file mode 100644 index 000000000000..d8ef44ab15c2 --- /dev/null +++ b/src/py/flwr/common/record/recordset.py @@ -0,0 +1,79 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
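With `set_parameters` removed, records are now populated via the constructor or plain dict-style assignment; a minimal sketch (names are illustrative, assuming only the `flwr.common.record` API introduced in this PR):

    from collections import OrderedDict

    import numpy as np

    from flwr.common.record import ParametersRecord, RecordSet, array_from_numpy

    arrays = OrderedDict(
        {f"layer-{i}": array_from_numpy(np.ones((2, 2))) for i in range(3)}
    )
    record = ParametersRecord(arrays, keep_input=False)
    assert len(arrays) == 0  # entries were moved into the record, not copied

    rs = RecordSet()
    rs.parameters_records["net"] = record  # replaces rs.set_parameters("net", record)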
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet.""" + + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Type, TypeVar + +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import ParametersRecord +from .typeddict import TypedDict + +T = TypeVar("T") + + +@dataclass +class RecordSet: + """RecordSet stores groups of parameters, metrics and configs.""" + + _parameters_records: TypedDict[str, ParametersRecord] + _metrics_records: TypedDict[str, MetricsRecord] + _configs_records: TypedDict[str, ConfigsRecord] + + def __init__( + self, + parameters_records: Optional[Dict[str, ParametersRecord]] = None, + metrics_records: Optional[Dict[str, MetricsRecord]] = None, + configs_records: Optional[Dict[str, ConfigsRecord]] = None, + ) -> None: + def _get_check_fn(__t: Type[T]) -> Callable[[T], None]: + def _check_fn(__v: T) -> None: + if not isinstance(__v, __t): + raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.") + + return _check_fn + + self._parameters_records = TypedDict[str, ParametersRecord]( + _get_check_fn(str), _get_check_fn(ParametersRecord) + ) + self._metrics_records = TypedDict[str, MetricsRecord]( + _get_check_fn(str), _get_check_fn(MetricsRecord) + ) + self._configs_records = TypedDict[str, ConfigsRecord]( + _get_check_fn(str), _get_check_fn(ConfigsRecord) + ) + if parameters_records is not None: + self._parameters_records.update(parameters_records) + if metrics_records is not None: + self._metrics_records.update(metrics_records) + if configs_records is not None: + self._configs_records.update(configs_records) + + @property + def parameters_records(self) -> TypedDict[str, ParametersRecord]: + """Dictionary holding ParametersRecord instances.""" + return self._parameters_records + + @property + def metrics_records(self) -> TypedDict[str, MetricsRecord]: + """Dictionary holding MetricsRecord instances.""" + return self._metrics_records + + @property + def configs_records(self) -> TypedDict[str, ConfigsRecord]: + """Dictionary holding ConfigsRecord instances.""" + return self._configs_records diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/record/recordset_test.py similarity index 91% rename from src/py/flwr/common/recordset_test.py rename to src/py/flwr/common/record/recordset_test.py index cb199813f450..66560d8337f9 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -20,15 +20,12 @@ import numpy as np import pytest -from .configsrecord import ConfigsRecord -from .metricsrecord import MetricsRecord -from .parameter import ndarrays_to_parameters, parameters_to_ndarrays -from .parametersrecord import Array, ParametersRecord -from .recordset_compat import ( +from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.recordset_compat import ( parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import ( +from flwr.common.typing import ( ConfigsRecordValues, MetricsRecordValues, NDArray, @@ 
-36,6 +33,8 @@ Parameters, ) +from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord + def get_ndarrays() -> NDArrays: """Return list of NumPy arrays.""" @@ -124,15 +123,14 @@ def test_parameters_to_parametersrecord_and_back( def test_set_parameters_while_keeping_intputs() -> None: """Tests keep_input functionality in ParametersRecord.""" # Adding parameters to a record that doesn't erase entries in the input `array_dict` - p_record = ParametersRecord(keep_input=True) array_dict = OrderedDict( {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} ) - p_record.set_parameters(array_dict, keep_input=True) + p_record = ParametersRecord(array_dict, keep_input=True) # Creating a second parametersrecord passing the same `array_dict` (not erased) p_record_2 = ParametersRecord(array_dict) - assert p_record.data == p_record_2.data + assert p_record == p_record_2 # Now it should be empty (the second ParametersRecord wasn't flagged to keep it) assert len(array_dict) == 0 @@ -144,7 +142,7 @@ def test_set_parameters_with_correct_types() -> None: array_dict = OrderedDict( {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} ) - p_record.set_parameters(array_dict) + p_record.update(array_dict) @pytest.mark.parametrize( @@ -169,7 +167,7 @@ def test_set_parameters_with_incorrect_types( } with pytest.raises(TypeError): - p_record.set_parameters(array_dict) # type: ignore + p_record.update(array_dict) @pytest.mark.parametrize( @@ -197,10 +195,10 @@ def test_set_metrics_to_metricsrecord_with_correct_types( ) # Add metric - m_record.set_metrics(my_metrics) + m_record.update(my_metrics) # Check metrics are actually added - assert my_metrics == m_record.data + assert my_metrics == m_record @pytest.mark.parametrize( @@ -250,7 +248,7 @@ def test_set_metrics_to_metricsrecord_with_incorrect_types( ) with pytest.raises(TypeError): - m_record.set_metrics(my_metrics) # type: ignore + m_record.update(my_metrics) @pytest.mark.parametrize( @@ -264,8 +262,6 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( keep_input: bool, ) -> None: """Test keep_input functionality for MetricsRecord.""" - m_record = MetricsRecord(keep_input=keep_input) - # constructing a valid input labels = [1, 2.0] arrays = get_ndarrays() @@ -276,14 +272,14 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( my_metrics_copy = my_metrics.copy() # Add metric - m_record.set_metrics(my_metrics, keep_input=keep_input) + m_record = MetricsRecord(my_metrics, keep_input=keep_input) # Check metrics are actually added # Check that input dict has been emptied when enabled such behaviour if keep_input: - assert my_metrics == m_record.data + assert my_metrics == m_record else: - assert my_metrics_copy == m_record.data + assert my_metrics_copy == m_record assert len(my_metrics) == 0 @@ -318,7 +314,7 @@ def test_set_configs_to_configsrecord_with_correct_types( c_record = ConfigsRecord(my_configs) # check values are actually there - assert c_record.data == my_configs + assert c_record == my_configs @pytest.mark.parametrize( @@ -352,14 +348,14 @@ def test_set_configs_to_configsrecord_with_incorrect_types( value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], ) -> None: """Test adding configs of various unsupported types to a ConfigsRecord.""" - m_record = ConfigsRecord() + c_record = ConfigsRecord() labels = [1, 2.0] arrays = get_ndarrays() - my_metrics = OrderedDict( + my_configs = OrderedDict( {key_type(label): value_fn(arr) for 
label, arr in zip(labels, arrays)} ) with pytest.raises(TypeError): - m_record.set_configs(my_metrics) # type: ignore + c_record.update(my_configs) diff --git a/src/py/flwr/common/record/typeddict.py b/src/py/flwr/common/record/typeddict.py new file mode 100644 index 000000000000..23d70dc4f7e8 --- /dev/null +++ b/src/py/flwr/common/record/typeddict.py @@ -0,0 +1,113 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Typed dict base class for *Records.""" + + +from typing import Any, Callable, Dict, Generic, Iterator, Tuple, TypeVar, cast + +K = TypeVar("K") # Key type +V = TypeVar("V") # Value type + + +class TypedDict(Generic[K, V]): + """Typed dictionary.""" + + def __init__( + self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None] + ): + self._data: Dict[K, V] = {} + self._check_key_fn = check_key_fn + self._check_value_fn = check_value_fn + + def __setitem__(self, key: K, value: V) -> None: + """Set the given key to the given value after type checking.""" + # Check the types of key and value + self._check_key_fn(key) + self._check_value_fn(value) + # Set key-value pair + self._data[key] = value + + def __delitem__(self, key: K) -> None: + """Remove the item with the specified key.""" + del self._data[key] + + def __getitem__(self, item: K) -> V: + """Return the value for the specified key.""" + return self._data[item] + + def __iter__(self) -> Iterator[K]: + """Yield an iterator over the keys of the dictionary.""" + return iter(self._data) + + def __repr__(self) -> str: + """Return a string representation of the dictionary.""" + return self._data.__repr__() + + def __len__(self) -> int: + """Return the number of items in the dictionary.""" + return len(self._data) + + def __contains__(self, key: K) -> bool: + """Check if the dictionary contains the specified key.""" + return key in self._data + + def __eq__(self, other: object) -> bool: + """Compare this instance to another dictionary or TypedDict.""" + if isinstance(other, TypedDict): + return self._data == other._data + if isinstance(other, dict): + return self._data == other + return NotImplemented + + def items(self) -> Iterator[Tuple[K, V]]: + """R.items() -> a set-like object providing a view on R's items.""" + return cast(Iterator[Tuple[K, V]], self._data.items()) + + def keys(self) -> Iterator[K]: + """R.keys() -> a set-like object providing a view on R's keys.""" + return cast(Iterator[K], self._data.keys()) + + def values(self) -> Iterator[V]: + """R.values() -> an object providing a view on R's values.""" + return cast(Iterator[V], self._data.values()) + + def update(self, *args: Any, **kwargs: Any) -> None: + """R.update([E, ]**F) -> None. + + Update R from dict/iterable E and F. + """ + for key, value in dict(*args, **kwargs).items(): + self[key] = value + + def pop(self, key: K) -> V: + """R.pop(k[,d]) -> v, remove specified key and return the corresponding value. 
+
+        If key is not found, KeyError is raised.
+        """
+        return self._data.pop(key)
+
+    def get(self, key: K, default: V) -> V:
+        """R.get(k,d) -> R[k] if k in R, else d."""
+        return self._data.get(key, default)
+
+    def clear(self) -> None:
+        """R.clear() -> None.
+
+        Remove all items from R.
+        """
+        self._data.clear()
diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py
deleted file mode 100644
index 61c880c970b8..000000000000
--- a/src/py/flwr/common/recordset.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""RecordSet."""
-
-
-from dataclasses import dataclass, field
-from typing import Dict
-
-from .configsrecord import ConfigsRecord
-from .metricsrecord import MetricsRecord
-from .parametersrecord import ParametersRecord
-
-
-@dataclass
-class RecordSet:
-    """Definition of RecordSet."""
-
-    parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
-    metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
-    configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
-
-    def set_parameters(self, name: str, record: ParametersRecord) -> None:
-        """Add a ParametersRecord."""
-        self.parameters[name] = record
-
-    def get_parameters(self, name: str) -> ParametersRecord:
-        """Get a ParametesRecord."""
-        return self.parameters[name]
-
-    def del_parameters(self, name: str) -> None:
-        """Delete a ParametersRecord."""
-        del self.parameters[name]
-
-    def set_metrics(self, name: str, record: MetricsRecord) -> None:
-        """Add a MetricsRecord."""
-        self.metrics[name] = record
-
-    def get_metrics(self, name: str) -> MetricsRecord:
-        """Get a MetricsRecord."""
-        return self.metrics[name]
-
-    def del_metrics(self, name: str) -> None:
-        """Delete a MetricsRecord."""
-        del self.metrics[name]
-
-    def set_configs(self, name: str, record: ConfigsRecord) -> None:
-        """Add a ConfigsRecord."""
-        self.configs[name] = record
-
-    def get_configs(self, name: str) -> ConfigsRecord:
-        """Get a ConfigsRecord."""
-        return self.configs[name]
-
-    def del_configs(self, name: str) -> None:
-        """Delete a ConfigsRecord."""
-        del self.configs[name]
diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py
index e0e591048820..394ea1353bab 100644
--- a/src/py/flwr/common/recordset_compat.py
+++ b/src/py/flwr/common/recordset_compat.py
@@ -17,10 +17,7 @@
 
 from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
 
-from .configsrecord import ConfigsRecord
-from .metricsrecord import MetricsRecord
-from .parametersrecord import Array, ParametersRecord
-from .recordset import RecordSet
+from .
import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet from .typing import ( Code, ConfigsRecordValues, @@ -61,7 +58,7 @@ def parametersrecord_to_parameters( """ parameters = Parameters(tensors=[], tensor_type="") - for key in list(record.data.keys()): + for key in list(record.keys()): parameters.tensors.append(record[key].data) if not parameters.tensor_type: @@ -70,7 +67,7 @@ def parametersrecord_to_parameters( parameters.tensor_type = record[key].stype if not keep_input: - del record.data[key] + del record[key] return parameters @@ -95,8 +92,6 @@ def parameters_to_parametersrecord( """ tensor_type = parameters.tensor_type - p_record = ParametersRecord() - num_arrays = len(parameters.tensors) ordered_dict = OrderedDict() for idx in range(num_arrays): @@ -108,8 +103,7 @@ def parameters_to_parametersrecord( data=tensor, dtype="", stype=tensor_type, shape=[] ) - p_record.set_parameters(ordered_dict, keep_input=keep_input) - return p_record + return ParametersRecord(ordered_dict, keep_input=keep_input) def _check_mapping_from_recordscalartype_to_scalar( @@ -135,16 +129,16 @@ def _recordset_to_fit_or_evaluate_ins_components( ) -> Tuple[Parameters, Dict[str, Scalar]]: """Derive Fit/Evaluate Ins from a RecordSet.""" # get Array and construct Parameters - parameters_record = recordset.get_parameters(f"{ins_str}.parameters") + parameters_record = recordset.parameters_records[f"{ins_str}.parameters"] parameters = parametersrecord_to_parameters( parameters_record, keep_input=keep_input ) # get config dict - config_record = recordset.get_configs(f"{ins_str}.config") - - config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records[f"{ins_str}.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return parameters, config_dict @@ -155,13 +149,11 @@ def _fit_or_evaluate_ins_to_recordset( recordset = RecordSet() ins_str = "fitins" if isinstance(ins, FitIns) else "evaluateins" - recordset.set_parameters( - name=f"{ins_str}.parameters", - record=parameters_to_parametersrecord(ins.parameters, keep_input=keep_input), - ) + parametersrecord = parameters_to_parametersrecord(ins.parameters, keep_input) + recordset.parameters_records[f"{ins_str}.parameters"] = parametersrecord - recordset.set_configs( - name=f"{ins_str}.config", record=ConfigsRecord(ins.config) # type: ignore + recordset.configs_records[f"{ins_str}.config"] = ConfigsRecord( + ins.config # type: ignore ) return recordset @@ -176,12 +168,12 @@ def _embed_status_into_recordset( } # we add it to a `ConfigsRecord`` because the `status.message`` is a string # and `str` values aren't supported in `MetricsRecords` - recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(status_dict)) + recordset.configs_records[f"{res_str}.status"] = ConfigsRecord(status_dict) return recordset def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status: - status = recordset.get_configs(f"{res_str}.status") + status = recordset.configs_records[f"{res_str}.status"] code = cast(int, status["code"]) return Status(code=Code(code), message=str(status["message"])) @@ -206,15 +198,15 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes: """Derive FitRes from a RecordSet object.""" ins_str = "fitres" parameters = parametersrecord_to_parameters( - recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input + 
recordset.parameters_records[f"{ins_str}.parameters"], keep_input=keep_input ) num_examples = cast( - int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] ) - configs_record = recordset.get_configs(f"{ins_str}.metrics") - - metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + configs_record = recordset.configs_records[f"{ins_str}.metrics"] + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) status = _extract_status_from_recordset(ins_str, recordset) return FitRes( @@ -228,16 +220,17 @@ def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet: res_str = "fitres" - recordset.set_configs( - name=f"{res_str}.metrics", record=ConfigsRecord(fitres.metrics) # type: ignore + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + fitres.metrics # type: ignore ) - recordset.set_metrics( - name=f"{res_str}.num_examples", - record=MetricsRecord({"num_examples": fitres.num_examples}), + recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": fitres.num_examples}, ) - recordset.set_parameters( - name=f"{res_str}.parameters", - record=parameters_to_parametersrecord(fitres.parameters, keep_input), + recordset.parameters_records[f"{res_str}.parameters"] = ( + parameters_to_parametersrecord( + fitres.parameters, + keep_input, + ) ) # status @@ -266,14 +259,15 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes: """Derive EvaluateRes from a RecordSet object.""" ins_str = "evaluateres" - loss = cast(int, recordset.get_metrics(f"{ins_str}.loss")["loss"]) + loss = cast(int, recordset.metrics_records[f"{ins_str}.loss"]["loss"]) num_examples = cast( - int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + int, recordset.metrics_records[f"{ins_str}.num_examples"]["num_examples"] ) - configs_record = recordset.get_configs(f"{ins_str}.metrics") + configs_record = recordset.configs_records[f"{ins_str}.metrics"] - metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + # pylint: disable-next=protected-access + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record._data) status = _extract_status_from_recordset(ins_str, recordset) return EvaluateRes( @@ -287,21 +281,18 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: res_str = "evaluateres" # loss - recordset.set_metrics( - name=f"{res_str}.loss", - record=MetricsRecord({"loss": evaluateres.loss}), + recordset.metrics_records[f"{res_str}.loss"] = MetricsRecord( + {"loss": evaluateres.loss}, ) # num_examples - recordset.set_metrics( - name=f"{res_str}.num_examples", - record=MetricsRecord({"num_examples": evaluateres.num_examples}), + recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord( + {"num_examples": evaluateres.num_examples}, ) # metrics - recordset.set_configs( - name=f"{res_str}.metrics", - record=ConfigsRecord(evaluateres.metrics), # type: ignore + recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord( + evaluateres.metrics, # type: ignore ) # status @@ -314,9 +305,9 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns: """Derive GetParametersIns from a RecordSet object.""" - config_record = recordset.get_configs("getparametersins.config") - - config_dict = 
_check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records["getparametersins.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return GetParametersIns(config=config_dict) @@ -325,9 +316,8 @@ def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> Record """Construct a RecordSet from a GetParametersIns object.""" recordset = RecordSet() - recordset.set_configs( - name="getparametersins.config", - record=ConfigsRecord(getparameters_ins.config), # type: ignore + recordset.configs_records["getparametersins.config"] = ConfigsRecord( + getparameters_ins.config, # type: ignore ) return recordset @@ -341,7 +331,7 @@ def getparametersres_to_recordset( parameters_record = parameters_to_parametersrecord( getparametersres.parameters, keep_input=keep_input ) - recordset.set_parameters(f"{res_str}.parameters", parameters_record) + recordset.parameters_records[f"{res_str}.parameters"] = parameters_record # status recordset = _embed_status_into_recordset( @@ -357,7 +347,7 @@ def recordset_to_getparametersres( """Derive GetParametersRes from a RecordSet object.""" res_str = "getparametersres" parameters = parametersrecord_to_parameters( - recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input + recordset.parameters_records[f"{res_str}.parameters"], keep_input=keep_input ) status = _extract_status_from_recordset(res_str, recordset) @@ -366,8 +356,9 @@ def recordset_to_getparametersres( def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: """Derive GetPropertiesIns from a RecordSet object.""" - config_record = recordset.get_configs("getpropertiesins.config") - config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records["getpropertiesins.config"] + # pylint: disable-next=protected-access + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record._data) return GetPropertiesIns(config=config_dict) @@ -375,9 +366,8 @@ def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet: """Construct a RecordSet from a GetPropertiesRes object.""" recordset = RecordSet() - recordset.set_configs( - name="getpropertiesins.config", - record=ConfigsRecord(getpropertiesins.config), # type: ignore + recordset.configs_records["getpropertiesins.config"] = ConfigsRecord( + getpropertiesins.config, # type: ignore ) return recordset @@ -385,8 +375,9 @@ def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordS def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes: """Derive GetPropertiesRes from a RecordSet object.""" res_str = "getpropertiesres" - config_record = recordset.get_configs(f"{res_str}.properties") - properties = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + config_record = recordset.configs_records[f"{res_str}.properties"] + # pylint: disable-next=protected-access + properties = _check_mapping_from_recordscalartype_to_scalar(config_record._data) status = _extract_status_from_recordset(res_str, recordset=recordset) @@ -397,9 +388,8 @@ def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordS """Construct a RecordSet from a GetPropertiesRes object.""" recordset = RecordSet() res_str = "getpropertiesres" - recordset.set_configs( - name=f"{res_str}.properties", - 
record=ConfigsRecord(getpropertiesres.properties), # type: ignore + recordset.configs_records[f"{res_str}.properties"] = ConfigsRecord( + getpropertiesres.properties, # type: ignore ) # status recordset = _embed_status_into_recordset( diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 92c4e2cdad00..531a4bde6e9d 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -20,6 +20,7 @@ from google.protobuf.message import Message as GrpcMessage # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node from flwr.proto.recordset_pb2 import Array as ProtoArray from flwr.proto.recordset_pb2 import BoolList, BytesList from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord @@ -42,12 +43,9 @@ ) # pylint: enable=E0611 -from . import typing -from .configsrecord import ConfigsRecord +from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing from .message import Message, Metadata -from .metricsrecord import MetricsRecord -from .parametersrecord import Array, ParametersRecord -from .recordset import RecordSet +from .record.typeddict import TypedDict # === Parameters message === @@ -413,7 +411,9 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any: def _record_value_dict_to_proto( - value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T] + value_dict: TypedDict[str, Any], + allowed_types: List[type], + value_proto_class: Type[T], ) -> Dict[str, T]: """Serialize the record value dict to ProtoBuf. @@ -455,8 +455,8 @@ def array_from_proto(array_proto: ProtoArray) -> Array: def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: """Serialize ParametersRecord to ProtoBuf.""" return ProtoParametersRecord( - data_keys=record.data.keys(), - data_values=map(array_to_proto, record.data.values()), + data_keys=record.keys(), + data_values=map(array_to_proto, record.values()), ) @@ -475,9 +475,7 @@ def parameters_record_from_proto( def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: """Serialize MetricsRecord to ProtoBuf.""" return ProtoMetricsRecord( - data=_record_value_dict_to_proto( - record.data, [float, int], ProtoMetricsRecordValue - ) + data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue) ) @@ -496,7 +494,9 @@ def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: """Serialize ConfigsRecord to ProtoBuf.""" return ProtoConfigsRecord( data=_record_value_dict_to_proto( - record.data, [bool, int, float, str, bytes], ProtoConfigsRecordValue + record, + [bool, int, float, str, bytes], + ProtoConfigsRecordValue, ) ) @@ -519,24 +519,29 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: """Serialize RecordSet to ProtoBuf.""" return ProtoRecordSet( parameters={ - k: parameters_record_to_proto(v) for k, v in recordset.parameters.items() + k: parameters_record_to_proto(v) + for k, v in recordset.parameters_records.items() + }, + metrics={ + k: metrics_record_to_proto(v) for k, v in recordset.metrics_records.items() + }, + configs={ + k: configs_record_to_proto(v) for k, v in recordset.configs_records.items() }, - metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()}, - configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()}, ) def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: """Deserialize RecordSet from ProtoBuf.""" return RecordSet( - parameters={ + parameters_records={ k: parameters_record_from_proto(v) for k, 
v in recordset_proto.parameters.items() }, - metrics={ + metrics_records={ k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() }, - configs={ + configs_records={ k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() }, ) @@ -547,11 +552,17 @@ def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: def message_to_taskins(message: Message) -> TaskIns: """Create a TaskIns from the Message.""" + md = message.metadata return TaskIns( + group_id=md.group_id, + run_id=md.run_id, task=Task( - ttl=message.metadata.ttl, - task_type=message.metadata.task_type, - recordset=recordset_to_proto(message.message), + producer=Node(node_id=0, anonymous=True), # Assume driver node + consumer=Node(node_id=md.dst_node_id, anonymous=False), + ttl=md.ttl, + ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], + task_type=md.message_type, + recordset=recordset_to_proto(message.content), ), ) @@ -561,26 +572,36 @@ def message_from_taskins(taskins: TaskIns) -> Message: # Retrieve the Metadata metadata = Metadata( run_id=taskins.run_id, - task_id=taskins.task_id, + message_id=taskins.task_id, + src_node_id=taskins.task.producer.node_id, + dst_node_id=taskins.task.consumer.node_id, + reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "", group_id=taskins.group_id, ttl=taskins.task.ttl, - task_type=taskins.task.task_type, + message_type=taskins.task.task_type, ) # Return the Message return Message( metadata=metadata, - message=recordset_from_proto(taskins.task.recordset), + content=recordset_from_proto(taskins.task.recordset), ) def message_to_taskres(message: Message) -> TaskRes: """Create a TaskRes from the Message.""" + md = message.metadata return TaskRes( + task_id="", # This will be generated by the server + group_id=md.group_id, + run_id=md.run_id, task=Task( - ttl=message.metadata.ttl, - task_type=message.metadata.task_type, - recordset=recordset_to_proto(message.message), + producer=Node(node_id=md.src_node_id, anonymous=False), + consumer=Node(node_id=0, anonymous=True), # Assume driver node + ttl=md.ttl, + ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], + task_type=md.message_type, + recordset=recordset_to_proto(message.content), ), ) @@ -590,14 +611,17 @@ def message_from_taskres(taskres: TaskRes) -> Message: # Retrieve the MetaData metadata = Metadata( run_id=taskres.run_id, - task_id=taskres.task_id, + message_id=taskres.task_id, + src_node_id=taskres.task.producer.node_id, + dst_node_id=taskres.task.consumer.node_id, + reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "", group_id=taskres.group_id, ttl=taskres.task.ttl, - task_type=taskres.task.task_type, + message_type=taskres.task.task_type, ) # Return the Message return Message( metadata=metadata, - message=recordset_from_proto(taskres.task.recordset), + content=recordset_from_proto(taskres.task.recordset), ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 2a229a87e399..1f25fd1852c1 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -28,12 +28,8 @@ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet # pylint: enable=E0611 -from . import typing -from .configsrecord import ConfigsRecord +from . 
import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing from .message import Message, Metadata -from .metricsrecord import MetricsRecord -from .parametersrecord import Array, ParametersRecord -from .recordset import RecordSet from .serde import ( array_from_proto, array_to_proto, @@ -199,15 +195,15 @@ def recordset( ) -> RecordSet: """Create a RecordSet.""" return RecordSet( - parameters={ + parameters_records={ self.get_str(): self.parameters_record() for _ in range(num_params_records) }, - metrics={ + metrics_records={ self.get_str(): self.metrics_record() for _ in range(num_metrics_records) }, - configs={ + configs_records={ self.get_str(): self.configs_record() for _ in range(num_configs_records) }, @@ -217,10 +213,13 @@ def metadata(self) -> Metadata: """Create a Metadata.""" return Metadata( run_id=self.rng.randint(0, 1 << 30), - task_id=self.get_str(64), + message_id=self.get_str(64), group_id=self.get_str(30), + src_node_id=self.rng.randint(0, 1 << 63), + dst_node_id=self.rng.randint(0, 1 << 63), + reply_to_message=self.get_str(64), ttl=self.get_str(10), - task_type=self.get_str(10), + message_type=self.get_str(10), ) @@ -251,7 +250,7 @@ def test_parameters_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoParametersRecord) - assert original.data == deserialized.data + assert original == deserialized def test_metrics_record_serialization_deserialization() -> None: @@ -266,7 +265,7 @@ def test_metrics_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoMetricsRecord) - assert original.data == deserialized.data + assert original == deserialized def test_configs_record_serialization_deserialization() -> None: @@ -281,7 +280,7 @@ def test_configs_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoConfigsRecord) - assert original.data == deserialized.data + assert original == deserialized def test_recordset_serialization_deserialization() -> None: @@ -304,26 +303,20 @@ def test_message_to_and_from_taskins() -> None: # Prepare maker = RecordMaker(state=1) metadata = maker.metadata() + # pylint: disable-next=protected-access + metadata._src_node_id = 0 # Assume driver node original = Message( - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl=metadata.ttl, - task_type=metadata.task_type, - ), - message=maker.recordset(1, 1, 1), + metadata=metadata, + content=maker.recordset(1, 1, 1), ) # Execute taskins = message_to_taskins(original) - taskins.run_id = metadata.run_id - taskins.task_id = metadata.task_id - taskins.group_id = metadata.group_id + taskins.task_id = metadata.message_id deserialized = message_from_taskins(taskins) # Assert - assert original.message == deserialized.message + assert original.content == deserialized.content assert metadata == deserialized.metadata @@ -332,24 +325,17 @@ def test_message_to_and_from_taskres() -> None: # Prepare maker = RecordMaker(state=2) metadata = maker.metadata() + metadata.dst_node_id = 0 # Assume driver node original = Message( - metadata=Metadata( - run_id=0, - task_id="", - group_id="", - ttl=metadata.ttl, - task_type=metadata.task_type, - ), - message=maker.recordset(1, 1, 1), + metadata=metadata, + content=maker.recordset(1, 1, 1), ) # Execute taskres = message_to_taskres(original) - taskres.run_id = metadata.run_id - taskres.task_id = metadata.task_id - taskres.group_id = metadata.group_id + taskres.task_id = metadata.message_id deserialized = message_from_taskres(taskres) # Assert - assert 
original.message == deserialized.message + assert original.content == deserialized.content assert metadata == deserialized.metadata diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index 2087c8b66ab5..8eb594085d31 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -32,7 +32,7 @@ FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1") FLWR_TELEMETRY_LOGGING = os.getenv("FLWR_TELEMETRY_LOGGING", "0") -TELEMETRY_EVENTS_URL = "https://telemetry.flower.dev/api/v1/event" +TELEMETRY_EVENTS_URL = "https://telemetry.flower.ai/api/v1/event" LOGGER_NAME = "flwr-telemetry" LOGGER_LEVEL = logging.DEBUG diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index 9545b4b68073..6808c66606b1 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -1,6 +1,5 @@ """Flower package version helper.""" - import importlib.metadata as importlib_metadata from typing import Tuple diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 84c24b4bc2c1..633bd668b520 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -15,29 +15,38 @@ """Flower server.""" -from . import driver, strategy -from .app import ServerConfig as ServerConfig +from . import strategy +from . import workflow as workflow from .app import run_driver_api as run_driver_api from .app import run_fleet_api as run_fleet_api -from .app import run_server_app as run_server_app from .app import run_superlink as run_superlink from .app import start_server as start_server from .client_manager import ClientManager as ClientManager from .client_manager import SimpleClientManager as SimpleClientManager +from .compat import LegacyContext as LegacyContext +from .compat import start_driver as start_driver +from .driver import Driver as Driver from .history import History as History +from .run_serverapp import run_server_app as run_server_app from .server import Server as Server +from .server_app import ServerApp as ServerApp +from .server_config import ServerConfig as ServerConfig __all__ = [ "ClientManager", - "driver", + "Driver", "History", + "LegacyContext", "run_driver_api", "run_fleet_api", "run_server_app", "run_superlink", "Server", + "ServerApp", "ServerConfig", "SimpleClientManager", + "start_driver", "start_server", "strategy", + "workflow", ] diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 6e2f1f7dc88d..a4913d51315b 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -19,12 +19,9 @@ import importlib.util import sys import threading -from dataclasses import dataclass -from logging import DEBUG, ERROR, INFO, WARN +from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path -from signal import SIGINT, SIGTERM, signal -from types import FrameType from typing import List, Optional, Tuple import grpc @@ -35,7 +32,9 @@ MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_VCE, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 add_DriverServicer_to_server, @@ -43,17 +42,20 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 add_FleetServicer_to_server, ) -from flwr.server.client_manager import ClientManager, SimpleClientManager -from flwr.server.history import History -from flwr.server.server import Server -from flwr.server.strategy import FedAvg, Strategy 
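The re-exported surface above makes it possible to declare a server app directly; a minimal sketch (the `ServerApp` constructor arguments are assumed to mirror `start_server`'s, per this PR's `server_app` module, and are not shown in this hunk):

    from flwr.server import ServerApp, ServerConfig
    from flwr.server.strategy import FedAvg

    # Hypothetical minimal app: run FedAvg for three rounds
    app = ServerApp(config=ServerConfig(num_rounds=3), strategy=FedAvg())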
-from flwr.server.superlink.driver.driver_servicer import DriverServicer -from flwr.server.superlink.fleet.grpc_bidi.grpc_server import ( + +from .client_manager import ClientManager +from .history import History +from .server import Server, init_defaults, run_fl +from .server_config import ServerConfig +from .strategy import Strategy +from .superlink.driver.driver_servicer import DriverServicer +from .superlink.fleet.grpc_bidi.grpc_server import ( generic_create_grpc_server, start_grpc_server, ) -from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer -from flwr.server.superlink.state import StateFactory +from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer +from .superlink.fleet.vce import start_vce +from .superlink.state import StateFactory ADDRESS_DRIVER_API = "0.0.0.0:9091" ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092" @@ -63,72 +65,6 @@ DATABASE = ":flwr-in-memory-state:" -@dataclass -class ServerConfig: - """Flower server config. - - All attributes have default values which allows users to configure just the ones - they care about. - """ - - num_rounds: int = 1 - round_timeout: Optional[float] = None - - -def run_server_app() -> None: - """Run Flower server app.""" - event(EventType.RUN_SERVER_APP_ENTER) - - args = _parse_args_run_server_app().parse_args() - - # Obtain certificates - if args.insecure: - if args.root_certificates is not None: - sys.exit( - "Conflicting options: The '--insecure' flag disables HTTPS, " - "but '--root-certificates' was also specified. Please remove " - "the '--root-certificates' option when running in insecure mode, " - "or omit '--insecure' to use HTTPS." - ) - log( - WARN, - "Option `--insecure` was set. " - "Starting insecure HTTP client connected to %s.", - args.server, - ) - root_certificates = None - else: - # Load the certificates if provided, or load the system certificates - cert_path = args.root_certificates - if cert_path is None: - root_certificates = None - else: - root_certificates = Path(cert_path).read_bytes() - log( - DEBUG, - "Starting secure HTTPS client connected to %s " - "with the following certificates: %s.", - args.server, - cert_path, - ) - - log( - DEBUG, - "Flower will load ServerApp `%s`", - getattr(args, "server-app"), - ) - - log( - DEBUG, - "root_certificates: `%s`", - root_certificates, - ) - - log(WARN, "Not implemented: run_server_app") - - event(EventType.RUN_SERVER_APP_LEAVE) - - def start_server( # pylint: disable=too-many-arguments,too-many-locals *, server_address: str = ADDRESS_FLEET_API_GRPC_BIDI, @@ -248,47 +184,6 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals return hist -def init_defaults( - server: Optional[Server], - config: Optional[ServerConfig], - strategy: Optional[Strategy], - client_manager: Optional[ClientManager], -) -> Tuple[Server, ServerConfig]: - """Create server instance if none was given.""" - if server is None: - if client_manager is None: - client_manager = SimpleClientManager() - if strategy is None: - strategy = FedAvg() - server = Server(client_manager=client_manager, strategy=strategy) - elif strategy is not None: - log(WARN, "Both server and strategy were provided, ignoring strategy") - - # Set default config values - if config is None: - config = ServerConfig() - - return server, config - - -def run_fl( - server: Server, - config: ServerConfig, -) -> History: - """Train a model on the given server and return the History object.""" - hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout) - log(INFO, 
"app_fit: losses_distributed %s", str(hist.losses_distributed)) - log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit)) - log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) - log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized)) - log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) - - # Graceful shutdown - server.disconnect_all_clients(timeout=config.round_timeout) - - return hist - - def run_driver_api() -> None: """Run Flower server (Driver API).""" log(INFO, "Starting Flower server (Driver API)") @@ -316,10 +211,10 @@ def run_driver_api() -> None: ) # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_DRIVER_API_LEAVE, grpc_servers=[grpc_server], bckg_threads=[], - event_type=EventType.RUN_DRIVER_API_LEAVE, ) # Block @@ -384,10 +279,10 @@ def run_fleet_api() -> None: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_FLEET_API_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_FLEET_API_LEAVE, ) # Block @@ -466,14 +361,23 @@ def run_superlink() -> None: certificates=certificates, ) grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_VCE: + _run_fleet_api_vce( + num_supernodes=args.num_supernodes, + client_app_module_name=args.client_app, + backend_name=args.backend, + backend_config_json_stream=args.backend_config, + working_dir=args.dir, + state_factory=state_factory, + ) else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_SUPERLINK_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_SUPERLINK_LEAVE, ) # Block @@ -508,52 +412,6 @@ def _try_obtain_certificates( return certificates -def _register_exit_handlers( - grpc_servers: List[grpc.Server], - bckg_threads: List[threading.Thread], - event_type: EventType, -) -> None: - default_handlers = { - SIGINT: None, - SIGTERM: None, - } - - def graceful_exit_handler( # type: ignore - signalnum, - frame: FrameType, # pylint: disable=unused-argument - ) -> None: - """Exit handler to be registered with signal.signal. - - When called will reset signal handler to original signal handler from - default_handlers. 
-        """
-        # Reset to default handler
-        signal(signalnum, default_handlers[signalnum])
-
-        event_res = event(event_type=event_type)
-
-        for grpc_server in grpc_servers:
-            grpc_server.stop(grace=1)
-
-        for bckg_thread in bckg_threads:
-            bckg_thread.join()
-
-        # Ensure event has happend
-        event_res.result()
-
-        # Setup things for graceful exit
-        sys.exit(0)
-
-    default_handlers[SIGINT] = signal(  # type: ignore
-        SIGINT,
-        graceful_exit_handler,  # type: ignore
-    )
-    default_handlers[SIGTERM] = signal(  # type: ignore
-        SIGTERM,
-        graceful_exit_handler,  # type: ignore
-    )
-
-
 def _run_driver_api_grpc(
     address: str,
     state_factory: StateFactory,
@@ -602,6 +460,27 @@ def _run_fleet_api_grpc_rere(
     return fleet_grpc_server
 
 
+# pylint: disable=too-many-arguments
+def _run_fleet_api_vce(
+    num_supernodes: int,
+    client_app_module_name: str,
+    backend_name: str,
+    backend_config_json_stream: str,
+    working_dir: str,
+    state_factory: StateFactory,
+) -> None:
+    log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
+
+    start_vce(
+        num_supernodes=num_supernodes,
+        client_app_module_name=client_app_module_name,
+        backend_name=backend_name,
+        backend_config_json_stream=backend_config_json_stream,
+        state_factory=state_factory,
+        working_dir=working_dir,
+    )
+
+
 # pylint: disable=import-outside-toplevel,too-many-arguments
 def _run_fleet_api_rest(
     host: str,
@@ -779,6 +658,14 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
         help="Start a Fleet API server (REST, experimental)",
     )
 
+    ex_group.add_argument(
+        "--vce",
+        action="store_const",
+        dest="fleet_api_type",
+        const=TRANSPORT_TYPE_VCE,
+        help="Start a Fleet API server (VirtualClientEngine)",
+    )
+
     # Fleet API gRPC-rere options
     grpc_rere_group = parser.add_argument_group(
         "Fleet API (gRPC-rere) server options", ""
@@ -815,41 +702,35 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
         default=1,
    )
 
-
-def _parse_args_run_server_app() -> argparse.ArgumentParser:
-    """Parse flower-server-app command line arguments."""
-    parser = argparse.ArgumentParser(
-        description="Start a Flower server app",
-    )
-
-    parser.add_argument(
-        "server-app",
-        help="For example: `server:app` or `project.package.module:wrapper.app`",
+    # Fleet API VCE options
+    vce_group = parser.add_argument_group("Fleet API (VCE) server options", "")
+    vce_group.add_argument(
+        "--client-app",
+        help="For example: `client:app` or `project.package.module:wrapper.app`.",
     )
-    parser.add_argument(
-        "--insecure",
-        action="store_true",
-        help="Run the server app without HTTPS. By default, the app runs with "
-        "HTTPS enabled. Use this flag only if you understand the risks.",
+    vce_group.add_argument(
+        "--num-supernodes",
+        type=int,
+        help="Number of simulated SuperNodes.",
     )
-    parser.add_argument(
-        "--root-certificates",
-        metavar="ROOT_CERT",
+    vce_group.add_argument(
+        "--backend",
+        default="ray",
         type=str,
-        help="Specifies the path to the PEM-encoded root certificate file for "
-        "establishing secure HTTPS connections.",
+        help="Simulation backend that executes the ClientApp.",
     )
-    parser.add_argument(
-        "--server",
-        default="0.0.0.0:9092",
-        help="Server address",
+    vce_group.add_argument(
+        "--backend-config",
+        type=str,
+        default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}',
+        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
+        "configure a backend. Values supported in <value> are those included by "
+        "`flwr.common.typing.ConfigsRecordValues`. ",
", ) parser.add_argument( "--dir", default="", - help="Add specified directory to the PYTHONPATH and load Flower " - "app from there." + help="Add specified directory to the PYTHONPATH and load" + "ClientApp from there." " Default: current working directory.", ) - - return parser diff --git a/src/py/flwr/server/client_proxy.py b/src/py/flwr/server/client_proxy.py index 7d0547be304b..951e4ae992da 100644 --- a/src/py/flwr/server/client_proxy.py +++ b/src/py/flwr/server/client_proxy.py @@ -47,6 +47,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" @@ -55,6 +56,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" @@ -63,6 +65,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> FitRes: """Refine the provided parameters using the locally held dataset.""" @@ -71,6 +74,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" @@ -79,5 +83,6 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" diff --git a/src/py/flwr/server/client_proxy_test.py b/src/py/flwr/server/client_proxy_test.py index 266e4cbeb266..685698558e3a 100644 --- a/src/py/flwr/server/client_proxy_test.py +++ b/src/py/flwr/server/client_proxy_test.py @@ -42,6 +42,7 @@ def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetPropertiesRes: """Return the client's properties.""" return GetPropertiesRes(status=Status(code=Code.OK, message=""), properties={}) @@ -50,6 +51,7 @@ def get_parameters( self, ins: GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> GetParametersRes: """Return the current local model parameters.""" return GetParametersRes( @@ -61,6 +63,7 @@ def fit( self, ins: FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> FitRes: """Refine the provided weights using the locally held dataset.""" return FitRes( @@ -74,6 +77,7 @@ def evaluate( self, ins: EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> EvaluateRes: """Evaluate the provided weights using the locally held dataset.""" return EvaluateRes( @@ -84,6 +88,7 @@ def reconnect( self, ins: ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> DisconnectRes: """Disconnect and (optionally) reconnect later.""" return DisconnectRes(reason="") diff --git a/src/py/flwr/server/compat/__init__.py b/src/py/flwr/server/compat/__init__.py new file mode 100644 index 000000000000..7bae196ddb65 --- /dev/null +++ b/src/py/flwr/server/compat/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ServerApp compatibility package.""" + + +from .app import start_driver as start_driver +from .legacy_context import LegacyContext as LegacyContext + +__all__ = [ + "LegacyContext", + "start_driver", +] diff --git a/src/py/flwr/server/driver/app.py b/src/py/flwr/server/compat/app.py similarity index 55% rename from src/py/flwr/server/driver/app.py rename to src/py/flwr/server/compat/app.py index ae47c58f4e9c..203317a3e348 100644 --- a/src/py/flwr/server/driver/app.py +++ b/src/py/flwr/server/compat/app.py @@ -16,24 +16,21 @@ import sys -import threading -import time from logging import INFO from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union from flwr.common import EventType, event from flwr.common.address import parse_address -from flwr.common.logger import log -from flwr.proto import driver_pb2 # pylint: disable=E0611 -from flwr.server.app import ServerConfig, init_defaults, run_fl +from flwr.common.logger import log, warn_deprecated_feature from flwr.server.client_manager import ClientManager from flwr.server.history import History -from flwr.server.server import Server +from flwr.server.server import Server, init_defaults, run_fl +from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy -from .driver_client_proxy import DriverClientProxy -from .grpc_driver import GrpcDriver +from ..driver import Driver +from .app_utils import start_update_client_manager_thread DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -53,6 +50,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, root_certificates: Optional[Union[bytes, str]] = None, + driver: Optional[Driver] = None, ) -> History: """Start a Flower Driver API server. @@ -80,6 +78,8 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. + driver : Optional[Driver] (default: None) + The Driver object to use. 
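Both call styles that this deprecation distinguishes, sketched under assumed values (address, strategy, and round count are illustrative):

    from flwr.server import ServerConfig
    from flwr.server.compat import start_driver
    from flwr.server.driver import Driver
    from flwr.server.strategy import FedAvg

    # Preferred after this change: construct the Driver yourself
    driver = Driver(driver_service_address="0.0.0.0:9091", root_certificates=None)
    hist = start_driver(
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
        driver=driver,
    )

    # Deprecated: let start_driver build a Driver from server_address
    # (triggers warn_deprecated_feature("start_driver")):
    # hist = start_driver(server_address="0.0.0.0:9091", strategy=FedAvg())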
Returns ------- @@ -100,21 +100,23 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals """ event(EventType.START_DRIVER_ENTER) - # Parse IP address - parsed_address = parse_address(server_address) - if not parsed_address: - sys.exit(f"Server IP address ({server_address}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - - # Create the Driver - if isinstance(root_certificates, str): - root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver( - driver_service_address=address, root_certificates=root_certificates - ) - driver.connect() - lock = threading.Lock() + if driver is None: + # Not passing a `Driver` object is deprecated + warn_deprecated_feature("start_driver") + + # Parse IP address + parsed_address = parse_address(server_address) + if not parsed_address: + sys.exit(f"Server IP address ({server_address}) cannot be parsed.") + host, port, is_v6 = parsed_address + address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + # Create the Driver + if isinstance(root_certificates, str): + root_certificates = Path(root_certificates).read_bytes() + driver = Driver( + driver_service_address=address, root_certificates=root_certificates + ) # Initialize the Driver API server and config initialized_server, initialized_config = init_defaults( @@ -130,15 +132,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals ) # Start the thread updating nodes - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - initialized_server.client_manager(), - lock, - ), + thread, f_stop = start_update_client_manager_thread( + driver, initialized_server.client_manager() ) - thread.start() # Start training hist = run_fl( @@ -146,68 +142,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals config=initialized_config, ) + f_stop.set() + # Stop the Driver API server and the thread - with lock: - driver.disconnect() + del driver + thread.join() event(EventType.START_SERVER_LEAVE) return hist - - -def update_client_manager( - driver: GrpcDriver, - client_manager: ClientManager, - lock: threading.Lock, -) -> None: - """Update the nodes list in the client manager. - - This function periodically communicates with the associated driver to get all - node_ids. Each node_id is then converted into a `DriverClientProxy` instance - and stored in the `registered_nodes` dictionary with node_id as key. - - New nodes will be added to the ClientManager via `client_manager.register()`, - and dead nodes will be removed from the ClientManager via - `client_manager.unregister()`. 
- """ - # Request for run_id - run_id = driver.create_run( - driver_pb2.CreateRunRequest() # pylint: disable=E1101 - ).run_id - - # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} - while True: - with lock: - # End the while loop if the driver is disconnected - if driver.stub is None: - break - get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101 - ) - all_node_ids = {node.node_id for node in get_nodes_res.nodes} - dead_nodes = set(registered_nodes).difference(all_node_ids) - new_nodes = all_node_ids.difference(registered_nodes) - - # Unregister dead nodes - for node_id in dead_nodes: - client_proxy = registered_nodes[node_id] - client_manager.unregister(client_proxy) - del registered_nodes[node_id] - - # Register new nodes - for node_id in new_nodes: - client_proxy = DriverClientProxy( - node_id=node_id, - driver=driver, - anonymous=False, - run_id=run_id, - ) - if client_manager.register(client_proxy): - registered_nodes[node_id] = client_proxy - else: - raise RuntimeError("Could not register node.") - - # Sleep for 3 seconds - time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py new file mode 100644 index 000000000000..696ec1132c4a --- /dev/null +++ b/src/py/flwr/server/compat/app_utils.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for the `start_driver`.""" + + +import threading +import time +from typing import Dict, Tuple + +from ..client_manager import ClientManager +from ..compat.driver_client_proxy import DriverClientProxy +from ..driver import Driver + + +def start_update_client_manager_thread( + driver: Driver, + client_manager: ClientManager, +) -> Tuple[threading.Thread, threading.Event]: + """Periodically update the nodes list in the client manager in a thread. + + This function starts a thread that periodically uses the associated driver to + get all node_ids. Each node_id is then converted into a `DriverClientProxy` + instance and stored in the `registered_nodes` dictionary with node_id as key. + + New nodes will be added to the ClientManager via `client_manager.register()`, + and dead nodes will be removed from the ClientManager via + `client_manager.unregister()`. + + Parameters + ---------- + driver : Driver + The Driver object to use. + client_manager : ClientManager + The ClientManager object to be updated. + + Returns + ------- + threading.Thread + A thread that updates the ClientManager and handles the stop event. + threading.Event + An event that, when set, signals the thread to stop. 
+ """ + f_stop = threading.Event() + thread = threading.Thread( + target=_update_client_manager, + args=( + driver, + client_manager, + f_stop, + ), + ) + thread.start() + + return thread, f_stop + + +def _update_client_manager( + driver: Driver, + client_manager: ClientManager, + f_stop: threading.Event, +) -> None: + """Update the nodes list in the client manager.""" + # Loop until the driver is disconnected + registered_nodes: Dict[int, DriverClientProxy] = {} + while not f_stop.is_set(): + all_node_ids = set(driver.get_node_ids()) + dead_nodes = set(registered_nodes).difference(all_node_ids) + new_nodes = all_node_ids.difference(registered_nodes) + + # Unregister dead nodes + for node_id in dead_nodes: + client_proxy = registered_nodes[node_id] + client_manager.unregister(client_proxy) + del registered_nodes[node_id] + + # Register new nodes + for node_id in new_nodes: + client_proxy = DriverClientProxy( + node_id=node_id, + driver=driver.grpc_driver, # type: ignore + anonymous=False, + run_id=driver.run_id, # type: ignore + ) + if client_manager.register(client_proxy): + registered_nodes[node_id] = client_proxy + else: + raise RuntimeError("Could not register node.") + + # Sleep for 3 seconds + time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py new file mode 100644 index 000000000000..7e47e6eaaf32 --- /dev/null +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for utility functions for the `start_driver`.""" + + +import time +import unittest +from unittest.mock import Mock, patch + +from ..client_manager import SimpleClientManager +from .app_utils import start_update_client_manager_thread + + +class TestUtils(unittest.TestCase): + """Tests for utility functions.""" + + def test_start_update_client_manager_thread(self) -> None: + """Test start_update_client_manager_thread function.""" + # Prepare + sleep = time.sleep + sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) + sleep_patch.start() + expected_node_ids = list(range(100)) + updated_expected_node_ids = list(range(80, 120)) + driver = Mock() + driver.grpc_driver = Mock() + driver.run_id = 123 + driver.get_node_ids.return_value = expected_node_ids + client_manager = SimpleClientManager() + + # Execute + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Update the GetNodesResponse and wait until the `client_manager` is updated + driver.get_node_ids.return_value = updated_expected_node_ids + sleep(0.1) + # Retrieve all nodes in `client_manager` + updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Stop the thread + f_stop.set() + + # Assert + assert node_ids == set(expected_node_ids) + assert updated_node_ids == set(updated_expected_node_ids) + + # Exit + thread.join() diff --git a/src/py/flwr/server/driver/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py similarity index 82% rename from src/py/flwr/server/driver/driver_client_proxy.py rename to src/py/flwr/server/compat/driver_client_proxy.py index e0ff26c035f7..46b2d92d9f04 100644 --- a/src/py/flwr/server/driver/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -19,19 +19,19 @@ from typing import List, Optional from flwr import common +from flwr.common import RecordSet from flwr.common import recordset_compat as compat from flwr.common import serde from flwr.common.constant import ( - TASK_TYPE_EVALUATE, - TASK_TYPE_FIT, - TASK_TYPE_GET_PARAMETERS, - TASK_TYPE_GET_PROPERTIES, + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, + MESSAGE_TYPE_GET_PROPERTIES, ) -from flwr.common.recordset import RecordSet from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy -from .grpc_driver import GrpcDriver +from ..driver.grpc_driver import GrpcDriver SLEEP_TIME = 1 @@ -47,57 +47,68 @@ def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: in self.anonymous = anonymous def get_properties( - self, ins: common.GetPropertiesIns, timeout: Optional[float] + self, + ins: common.GetPropertiesIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Return client's properties.""" # Ins to RecordSet out_recordset = compat.getpropertiesins_to_recordset(ins) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, TASK_TYPE_GET_PROPERTIES, timeout + out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id ) # RecordSet to Res return compat.recordset_to_getpropertiesres(in_recordset) def get_parameters( - self, ins: common.GetParametersIns, timeout: 
Optional[float] + self, + ins: common.GetParametersIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" # Ins to RecordSet out_recordset = compat.getparametersins_to_recordset(ins) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, TASK_TYPE_GET_PARAMETERS, timeout + out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id ) # RecordSet to Res return compat.recordset_to_getparametersres(in_recordset, False) - def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: + def fit( + self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> common.FitRes: """Train model parameters on the locally held dataset.""" # Ins to RecordSet out_recordset = compat.fitins_to_recordset(ins, keep_input=True) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, TASK_TYPE_FIT, timeout + out_recordset, MESSAGE_TYPE_FIT, timeout, group_id ) # RecordSet to Res return compat.recordset_to_fitres(in_recordset, keep_input=False) def evaluate( - self, ins: common.EvaluateIns, timeout: Optional[float] + self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" # Ins to RecordSet out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, TASK_TYPE_EVALUATE, timeout + out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id ) # RecordSet to Res return compat.recordset_to_evaluateres(in_recordset) def reconnect( - self, ins: common.ReconnectIns, timeout: Optional[float] + self, + ins: common.ReconnectIns, + timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) @@ -107,10 +118,11 @@ def _send_receive_recordset( recordset: RecordSet, task_type: str, timeout: Optional[float], + group_id: Optional[int], ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", - group_id="", + group_id=str(group_id) if group_id is not None else "", run_id=self.run_id, task=task_pb2.Task( # pylint: disable=E1101 producer=node_pb2.Node( # pylint: disable=E1101 diff --git a/src/py/flwr/server/driver/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py similarity index 91% rename from src/py/flwr/server/driver/driver_client_proxy_test.py rename to src/py/flwr/server/compat/driver_client_proxy_test.py index 18277a7ce80c..91f14553e5c2 100644 --- a/src/py/flwr/server/driver/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -25,10 +25,10 @@ from flwr.common import recordset_compat as compat from flwr.common import serde from flwr.common.constant import ( - TASK_TYPE_EVALUATE, - TASK_TYPE_FIT, - TASK_TYPE_GET_PARAMETERS, - TASK_TYPE_GET_PROPERTIES, + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, + MESSAGE_TYPE_GET_PROPERTIES, ) from flwr.common.typing import ( Code, @@ -44,7 +44,8 @@ Status, ) from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.server.driver.driver_client_proxy import DriverClientProxy + +from .driver_client_proxy import DriverClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -56,21 +57,21 @@ def _make_task( res: 
Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] ) -> task_pb2.Task: # pylint: disable=E1101 if isinstance(res, GetParametersRes): - task_type = TASK_TYPE_GET_PARAMETERS + message_type = MESSAGE_TYPE_GET_PARAMETERS recordset = compat.getparametersres_to_recordset(res, True) elif isinstance(res, GetPropertiesRes): - task_type = TASK_TYPE_GET_PROPERTIES + message_type = MESSAGE_TYPE_GET_PROPERTIES recordset = compat.getpropertiesres_to_recordset(res) elif isinstance(res, FitRes): - task_type = TASK_TYPE_FIT + message_type = MESSAGE_TYPE_FIT recordset = compat.fitres_to_recordset(res, True) elif isinstance(res, EvaluateRes): - task_type = TASK_TYPE_EVALUATE + message_type = MESSAGE_TYPE_EVALUATE recordset = compat.evaluateres_to_recordset(res) else: raise ValueError(f"Unsupported type: {type(res)}") return task_pb2.Task( # pylint: disable=E1101 - task_type=task_type, + task_type=message_type, recordset=serde.recordset_to_proto(recordset), ) @@ -102,7 +103,7 @@ def test_get_properties(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(0), run_id=0, task=_make_task( GetPropertiesRes( @@ -122,7 +123,9 @@ def test_get_properties(self) -> None: ) # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) # Assert assert value.properties["tensor_type"] == "numpy.ndarray" @@ -140,7 +143,7 @@ def test_get_parameters(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(0), run_id=0, task=_make_task( GetParametersRes( @@ -159,7 +162,7 @@ def test_get_parameters(self) -> None: # Execute value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None + ins=get_parameters_ins, timeout=None, group_id=0 ) # Assert @@ -178,7 +181,7 @@ def test_fit(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(1), run_id=0, task=_make_task( FitRes( @@ -199,7 +202,7 @@ def test_fit(self) -> None: ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) # Execute - fit_res = client.fit(ins=ins, timeout=None) + fit_res = client.fit(ins=ins, timeout=None, group_id=1) # Assert assert fit_res.parameters.tensor_type == "np" @@ -219,7 +222,7 @@ def test_evaluate(self) -> None: task_res_list=[ task_pb2.TaskRes( # pylint: disable=E1101 task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", + group_id=str(1), run_id=0, task=_make_task( EvaluateRes( @@ -240,7 +243,7 @@ def test_evaluate(self) -> None: evaluate_ins = EvaluateIns(parameters, {}) # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) # Assert assert 0.0 == evaluate_res.loss diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py new file mode 100644 index 000000000000..0b00c98bb16d --- /dev/null +++ b/src/py/flwr/server/compat/legacy_context.py @@ -0,0 +1,55 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
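How the renamed `MESSAGE_TYPE_*` constants and the stringified `group_id` fit together, as a sketch with the transport step elided; `server_round` is illustrative:

    from flwr.common import FitIns, Parameters
    from flwr.common import recordset_compat as compat
    from flwr.common.constant import MESSAGE_TYPE_FIT  # was TASK_TYPE_FIT

    server_round = 1
    ins = FitIns(Parameters(tensors=[b"abc"], tensor_type="np"), config={})

    out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
    task_type = MESSAGE_TYPE_FIT  # stored in TaskIns.task_type
    group_id = str(server_round)  # TaskIns carries group_id as a string
    # ... push TaskIns / pull TaskRes, then:
    # fit_res = compat.recordset_to_fitres(in_recordset, keep_input=False)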
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Legacy Context.""" + + +from dataclasses import dataclass +from typing import Optional + +from flwr.common import Context, RecordSet + +from ..client_manager import ClientManager, SimpleClientManager +from ..history import History +from ..server_config import ServerConfig +from ..strategy import FedAvg, Strategy + + +@dataclass +class LegacyContext(Context): + """Legacy Context.""" + + config: ServerConfig + strategy: Strategy + client_manager: ClientManager + history: History + + def __init__( + self, + state: RecordSet, + config: Optional[ServerConfig] = None, + strategy: Optional[Strategy] = None, + client_manager: Optional[ClientManager] = None, + ) -> None: + if config is None: + config = ServerConfig() + if strategy is None: + strategy = FedAvg() + if client_manager is None: + client_manager = SimpleClientManager() + self.config = config + self.strategy = strategy + self.client_manager = client_manager + self.history = History() + super().__init__(state) diff --git a/src/py/flwr/server/driver/__init__.py b/src/py/flwr/server/driver/__init__.py index 1c3b09cc334b..b61f6eebf6a8 100644 --- a/src/py/flwr/server/driver/__init__.py +++ b/src/py/flwr/server/driver/__init__.py @@ -15,12 +15,10 @@ """Flower driver SDK.""" -from .app import start_driver from .driver import Driver from .grpc_driver import GrpcDriver __all__ = [ "Driver", "GrpcDriver", - "start_driver", ] diff --git a/src/py/flwr/server/driver/app_test.py b/src/py/flwr/server/driver/app_test.py deleted file mode 100644 index 03f490807876..000000000000 --- a/src/py/flwr/server/driver/app_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower Driver app tests.""" - - -import threading -import time -import unittest -from unittest.mock import MagicMock - -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunResponse, - GetNodesResponse, -) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.server.client_manager import SimpleClientManager -from flwr.server.driver.app import update_client_manager - - -class TestClientManagerWithDriver(unittest.TestCase): - """Tests for ClientManager. - - Considering multi-threading, all tests assume that the `update_client_manager()` - updates the ClientManager every 3 seconds. 
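Constructing a `LegacyContext` then looks like the sketch below; any field left out falls back to the defaults wired into `__init__` above (`ServerConfig()`, `FedAvg`, `SimpleClientManager`):

    from flwr.common import Context, RecordSet
    from flwr.server import ServerConfig
    from flwr.server.compat import LegacyContext
    from flwr.server.strategy import FedAvg

    context: Context = LegacyContext(
        state=RecordSet(),
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
        # client_manager omitted -> SimpleClientManager()
    )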
- """ - - def test_simple_client_manager_update(self) -> None: - """Tests if the node update works correctly.""" - # Prepare - expected_nodes = [Node(node_id=i, anonymous=False) for i in range(100)] - expected_updated_nodes = [ - Node(node_id=i, anonymous=False) for i in range(80, 120) - ] - driver = MagicMock() - driver.stub = "driver stub" - driver.create_run.return_value = CreateRunResponse(run_id=1) - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) - client_manager = SimpleClientManager() - lock = threading.Lock() - - # Execute - thread = threading.Thread( - target=update_client_manager, - args=( - driver, - client_manager, - lock, - ), - daemon=True, - ) - thread.start() - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_nodes)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_updated_nodes) - while True: - with lock: - if len(client_manager.all()) == len(expected_updated_nodes): - break - time.sleep(1.3) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Simulate `driver.disconnect()` - driver.stub = None - - # Assert - driver.create_run.assert_called_once() - assert node_ids == {node.node_id for node in expected_nodes} - assert updated_node_ids == {node.node_id for node in expected_updated_nodes} - - # Exit - thread.join() diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 0a7cb36f8847..bcaac1f61b85 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -15,8 +15,11 @@ """Flower driver service client.""" +import time from typing import Iterable, List, Optional, Tuple +from flwr.common import Message, Metadata, RecordSet +from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, GetNodesRequest, @@ -24,8 +27,9 @@ PushTaskInsRequest, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + +from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver class Driver: @@ -68,44 +72,185 @@ def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: self.grpc_driver.connect() res = self.grpc_driver.create_run(CreateRunRequest()) self.run_id = res.run_id - return self.grpc_driver, self.run_id - def get_nodes(self) -> List[Node]: + def _check_message(self, message: Message) -> None: + # Check if the message is valid + if not ( + message.metadata.run_id == self.run_id + and message.metadata.src_node_id == self.node.node_id + and message.metadata.message_id == "" + and message.metadata.reply_to_message == "" + ): + raise ValueError(f"Invalid message: {message}") + + def create_message( # pylint: disable=too-many-arguments + self, + content: RecordSet, + message_type: str, + dst_node_id: int, + group_id: str, + ttl: str, + ) -> Message: + """Create a new message with specified parameters. + + This method constructs a new `Message` with given content and metadata. + The `run_id` and `src_node_id` will be set automatically. 
+ + Parameters + ---------- + content : RecordSet + The content for the new message. This holds records that are to be sent + to the destination node. + message_type : str + The type of the message, defining the action to be executed on + the receiving end. + dst_node_id : int + The ID of the destination node to which the message is being sent. + group_id : str + The ID of the group to which this message is associated. In some settings, + this is used as the FL round. + ttl : str + Time-to-live for the round trip of this message, i.e., the time from sending + this message to receiving a reply. It specifies the duration for which the + message and its potential reply are considered valid. + + Returns + ------- + message : Message + A new `Message` instance with the specified content and metadata. + """ + _, run_id = self._get_grpc_driver_and_run_id() + metadata = Metadata( + run_id=run_id, + message_id="", # Will be set by the server + src_node_id=self.node.node_id, + dst_node_id=dst_node_id, + reply_to_message="", + group_id=group_id, + ttl=ttl, + message_type=message_type, + ) + return Message(metadata=metadata, content=content) + + def get_node_ids(self) -> List[int]: """Get node IDs.""" grpc_driver, run_id = self._get_grpc_driver_and_run_id() - # Call GrpcDriver method res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) - return list(res.nodes) + return [node.node_id for node in res.nodes] - def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: - """Schedule tasks.""" - grpc_driver, run_id = self._get_grpc_driver_and_run_id() + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + """Push messages to specified node IDs. - # Set run_id - for task_ins in task_ins_list: - task_ins.run_id = run_id + This method takes an iterable of messages and sends each message + to the node specified in `dst_node_id`. + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + + Returns + ------- + message_ids : Iterable[str] + An iterable of IDs for the messages that were sent, which can be used + to pull replies. + """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Construct TaskIns + task_ins_list: List[TaskIns] = [] + for msg in messages: + # Check message + self._check_message(msg) + # Convert Message to TaskIns + taskins = message_to_taskins(msg) + # Add to list + task_ins_list.append(taskins) # Call GrpcDriver method res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) return list(res.task_ids) - def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: - """Get task results.""" - grpc_driver, _ = self._get_grpc_driver_and_run_id() + def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: + """Pull messages based on message IDs. - # Call GrpcDriver method + This method is used to collect messages from the SuperLink + that correspond to a set of given message IDs. + + Parameters + ---------- + message_ids : Iterable[str] + An iterable of message IDs for which reply messages are to be retrieved. + + Returns + ------- + messages : Iterable[Message] + An iterable of messages received. 
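The two halves compose as sketched below, again assuming a connected `driver` and messages `msgs` built via `create_message`; `send_and_receive`, defined next, wraps exactly this kind of polling:

    msg_ids = list(driver.push_messages(msgs))

    # ... later: collect whichever replies have arrived so far
    replies = list(driver.pull_messages(msg_ids))
    answered = {reply.metadata.reply_to_message for reply in replies}
    pending = [msg_id for msg_id in msg_ids if msg_id not in answered]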
+ """ + grpc_driver, _ = self._get_grpc_driver_and_run_id() + # Pull TaskRes res = grpc_driver.pull_task_res( - PullTaskResRequest(node=self.node, task_ids=task_ids) + PullTaskResRequest(node=self.node, task_ids=message_ids) ) - return list(res.task_res_list) + # Convert TaskRes to Message + msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] + return msgs + + def send_and_receive( + self, + messages: Iterable[Message], + *, + timeout: Optional[float] = None, + ) -> Iterable[Message]: + """Push messages to specified node IDs and pull the reply messages. + + This method sends a list of messages to their destination node IDs and then + waits for the replies. It continues to pull replies until either all + replies are received or the specified timeout duration is exceeded. + + Parameters + ---------- + messages : Iterable[Message] + An iterable of messages to be sent. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the method will wait for + replies for this duration. If `None`, there is no time limit and the method + will wait until replies for all messages are received. + + Returns + ------- + replies : Iterable[Message] + An iterable of reply messages received from the SuperLink. + + Notes + ----- + This method uses `push_messages` to send the messages and `pull_messages` + to collect the replies. If `timeout` is set, the method may not return + replies for all sent messages. A message remains valid until its TTL, + which is not affected by `timeout`. + """ + # Push messages + msg_ids = set(self.push_messages(messages)) + + # Pull messages + end_time = time.time() + (timeout if timeout is not None else 0.0) + ret: List[Message] = [] + while timeout is None or time.time() < end_time: + res_msgs = self.pull_messages(msg_ids) + ret.extend(res_msgs) + msg_ids.difference_update( + {msg.metadata.reply_to_message for msg in res_msgs} + ) + if len(msg_ids) == 0: + break + # Sleep + time.sleep(3) + return ret def __del__(self) -> None: """Disconnect GrpcDriver if connected.""" # Check if GrpcDriver is initialized if self.grpc_driver is None: return - # Disconnect self.grpc_driver.disconnect() diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py index 0ee7fbfec37e..bd5c23a407fd 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -15,16 +15,19 @@ """Tests for driver SDK.""" +import time import unittest from unittest.mock import Mock, patch +from flwr.common import RecordSet from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, PullTaskResRequest, PushTaskInsRequest, ) -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.driver.driver import Driver +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 + +from .driver import Driver class TestDriver(unittest.TestCase): @@ -73,11 +76,11 @@ def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() - mock_response.nodes = [Mock(), Mock()] + mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] self.mock_grpc_driver.get_nodes.return_value = mock_response # Execute - nodes = self.driver.get_nodes() + node_ids = self.driver.get_node_ids() args, kwargs = self.mock_grpc_driver.get_nodes.call_args # Assert @@ -86,18 +89,19 @@ def test_get_nodes(self) -> None: self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) self.assertEqual(args[0].run_id, 
61016) - self.assertEqual(nodes, mock_response.nodes) + self.assertEqual(node_ids, [404, 200]) - def test_push_task_ins(self) -> None: - """Test pushing task instructions.""" + def test_push_messages_valid(self) -> None: + """Test pushing valid messages.""" # Prepare - mock_response = Mock() - mock_response.task_ids = ["id1", "id2"] + mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response - task_ins_list = [TaskIns(), TaskIns()] + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] # Execute - task_ids = self.driver.push_task_ins(task_ins_list) + msg_ids = self.driver.push_messages(msgs) args, kwargs = self.mock_grpc_driver.push_task_ins.call_args # Assert @@ -105,12 +109,27 @@ def test_push_task_ins(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) - self.assertEqual(task_ids, mock_response.task_ids) + self.assertEqual(msg_ids, mock_response.task_ids) for task_ins in args[0].task_ins_list: self.assertEqual(task_ins.run_id, 61016) - def test_pull_task_res_with_given_task_ids(self) -> None: - """Test pulling task results with specific task IDs.""" + def test_push_messages_invalid(self) -> None: + """Test pushing invalid messages.""" + # Prepare + mock_response = Mock(task_ids=["id1", "id2"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + msgs = [ + self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + ] + # Use invalid run_id + msgs[1].metadata._run_id += 1 # pylint: disable=protected-access + + # Execute and assert + with self.assertRaises(ValueError): + self.driver.push_messages(msgs) + + def test_pull_messages_with_given_message_ids(self) -> None: + """Test pulling messages with specific message IDs.""" # Prepare mock_response = Mock() mock_response.task_res_list = [ @@ -118,10 +137,11 @@ def test_pull_task_res_with_given_task_ids(self) -> None: TaskRes(task=Task(ancestry=["id3"])), ] self.mock_grpc_driver.pull_task_res.return_value = mock_response - task_ids = ["id1", "id2", "id3"] + msg_ids = ["id1", "id2", "id3"] # Execute - task_res_list = self.driver.pull_task_res(task_ids) + msgs = self.driver.pull_messages(msg_ids) + reply_tos = {msg.metadata.reply_to_message for msg in msgs} args, kwargs = self.mock_grpc_driver.pull_task_res.call_args # Assert @@ -129,8 +149,43 @@ def test_pull_task_res_with_given_task_ids(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) - self.assertEqual(args[0].task_ids, task_ids) - self.assertEqual(task_res_list, mock_response.task_res_list) + self.assertEqual(args[0].task_ids, msg_ids) + self.assertEqual(reply_tos, {"id2", "id3"}) + + def test_send_and_receive_messages_complete(self) -> None: + """Test send and receive all messages successfully.""" + # Prepare + mock_response = Mock(task_ids=["id1"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + mock_response = Mock(task_res_list=[TaskRes(task=Task(ancestry=["id1"]))]) + self.mock_grpc_driver.pull_task_res.return_value = mock_response + msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + + # Execute + ret_msgs = list(self.driver.send_and_receive(msgs)) + + # Assert + self.assertEqual(len(ret_msgs), 1) + self.assertEqual(ret_msgs[0].metadata.reply_to_message, "id1") + + def test_send_and_receive_messages_timeout(self) -> None: + """Test send and receive messages but time out.""" + # 
Prepare + sleep_fn = time.sleep + mock_response = Mock(task_ids=["id1"]) + self.mock_grpc_driver.push_task_ins.return_value = mock_response + mock_response = Mock(task_res_list=[]) + self.mock_grpc_driver.pull_task_res.return_value = mock_response + msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + + # Execute + with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): + start_time = time.time() + ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15)) + + # Assert + self.assertLess(time.time() - start_time, 0.2) + self.assertEqual(len(ret_msgs), 0) def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py new file mode 100644 index 000000000000..19fd16fb0c1a --- /dev/null +++ b/src/py/flwr/server/run_serverapp.py @@ -0,0 +1,151 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run ServerApp.""" + + +import argparse +import sys +from logging import DEBUG, WARN +from pathlib import Path + +from flwr.common import Context, EventType, RecordSet, event +from flwr.common.logger import log + +from .driver.driver import Driver +from .server_app import ServerApp, load_server_app + + +def run(server_app_attr: str, driver: Driver, server_app_dir: str) -> None: + """Run ServerApp with a given Driver.""" + if server_app_dir is not None: + sys.path.insert(0, server_app_dir) + + def _load() -> ServerApp: + server_app: ServerApp = load_server_app(server_app_attr) + return server_app + + server_app = _load() + + # Initialize Context + context = Context(state=RecordSet()) + + # Call ServerApp + server_app(driver=driver, context=context) + + +def run_server_app() -> None: + """Run Flower server app.""" + event(EventType.RUN_SERVER_APP_ENTER) + + args = _parse_args_run_server_app().parse_args() + + # Obtain certificates + if args.insecure: + if args.root_certificates is not None: + sys.exit( + "Conflicting options: The '--insecure' flag disables HTTPS, " + "but '--root-certificates' was also specified. Please remove " + "the '--root-certificates' option when running in insecure mode, " + "or omit '--insecure' to use HTTPS." + ) + log( + WARN, + "Option `--insecure` was set. 
" + "Starting insecure HTTP client connected to %s.", + args.server, + ) + root_certificates = None + else: + # Load the certificates if provided, or load the system certificates + cert_path = args.root_certificates + if cert_path is None: + root_certificates = None + else: + root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) + + log( + DEBUG, + "Flower will load ServerApp `%s`", + getattr(args, "server-app"), + ) + + log( + DEBUG, + "root_certificates: `%s`", + root_certificates, + ) + + server_app_dir = args.dir + server_app_attr = getattr(args, "server-app") + + # Initialize Driver + driver = Driver( + driver_service_address=args.server, + root_certificates=root_certificates, + ) + + # Run the Server App with the Driver + run(server_app_attr, driver, server_app_dir) + + # Clean up + driver.__del__() # pylint: disable=unnecessary-dunder-call + + event(EventType.RUN_SERVER_APP_LEAVE) + + +def _parse_args_run_server_app() -> argparse.ArgumentParser: + """Parse flower-server-app command line arguments.""" + parser = argparse.ArgumentParser( + description="Start a Flower server app", + ) + + parser.add_argument( + "server-app", + help="For example: `server:app` or `project.package.module:wrapper.app`", + ) + parser.add_argument( + "--insecure", + action="store_true", + help="Run the server app without HTTPS. By default, the app runs with " + "HTTPS enabled. Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--root-certificates", + metavar="ROOT_CERT", + type=str, + help="Specifies the path to the PEM-encoded root certificate file for " + "establishing secure HTTPS connections.", + ) + parser.add_argument( + "--server", + default="0.0.0.0:9091", + help="Server address", + ) + parser.add_argument( + "--dir", + default="", + help="Add specified directory to the PYTHONPATH and load Flower " + "app from there." 
+ " Default: current working directory.", + ) + + return parser diff --git a/src/py/flwr/server/server.py b/src/py/flwr/server/server.py index cf3a4d9aa07c..dc7040193101 100644 --- a/src/py/flwr/server/server.py +++ b/src/py/flwr/server/server.py @@ -17,7 +17,7 @@ import concurrent.futures import timeit -from logging import DEBUG, INFO +from logging import DEBUG, INFO, WARN from typing import Dict, List, Optional, Tuple, Union from flwr.common import ( @@ -33,11 +33,13 @@ ) from flwr.common.logger import log from flwr.common.typing import GetParametersIns -from flwr.server.client_manager import ClientManager +from flwr.server.client_manager import ClientManager, SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.history import History from flwr.server.strategy import FedAvg, Strategy +from .server_config import ServerConfig + FitResultsAndFailures = Tuple[ List[Tuple[ClientProxy, FitRes]], List[Union[Tuple[ClientProxy, FitRes], BaseException]], @@ -87,7 +89,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: # Initialize parameters log(INFO, "Initializing global parameters") - self.parameters = self._get_initial_parameters(timeout=timeout) + self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout) log(INFO, "Evaluating initial parameters") res = self.strategy.evaluate(0, parameters=self.parameters) if res is not None: @@ -183,6 +185,7 @@ def evaluate_round( client_instructions, max_workers=self.max_workers, timeout=timeout, + group_id=server_round, ) log( DEBUG, @@ -232,6 +235,7 @@ def fit_round( client_instructions=client_instructions, max_workers=self.max_workers, timeout=timeout, + group_id=server_round, ) log( DEBUG, @@ -262,7 +266,9 @@ def disconnect_all_clients(self, timeout: Optional[float]) -> None: timeout=timeout, ) - def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters: + def _get_initial_parameters( + self, server_round: int, timeout: Optional[float] + ) -> Parameters: """Get initial parameters from one of the available clients.""" # Server-side parameter initialization parameters: Optional[Parameters] = self.strategy.initialize_parameters( @@ -276,7 +282,9 @@ def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters: log(INFO, "Requesting initial parameters from one random client") random_client = self._client_manager.sample(1)[0] ins = GetParametersIns(config={}) - get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout) + get_parameters_res = random_client.get_parameters( + ins=ins, timeout=timeout, group_id=server_round + ) log(INFO, "Received initial parameters from one random client") return get_parameters_res.parameters @@ -319,6 +327,7 @@ def reconnect_client( disconnect = client.reconnect( reconnect, timeout=timeout, + group_id=None, ) return client, disconnect @@ -327,11 +336,12 @@ def fit_clients( client_instructions: List[Tuple[ClientProxy, FitIns]], max_workers: Optional[int], timeout: Optional[float], + group_id: int, ) -> FitResultsAndFailures: """Refine parameters concurrently on all selected clients.""" with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: submitted_fs = { - executor.submit(fit_client, client_proxy, ins, timeout) + executor.submit(fit_client, client_proxy, ins, timeout, group_id) for client_proxy, ins in client_instructions } finished_fs, _ = concurrent.futures.wait( @@ -350,10 +360,10 @@ def fit_clients( def fit_client( - client: ClientProxy, ins: FitIns, timeout: Optional[float] + 
client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int ) -> Tuple[ClientProxy, FitRes]: """Refine parameters on a single client.""" - fit_res = client.fit(ins, timeout=timeout) + fit_res = client.fit(ins, timeout=timeout, group_id=group_id) return client, fit_res @@ -386,11 +396,12 @@ def evaluate_clients( client_instructions: List[Tuple[ClientProxy, EvaluateIns]], max_workers: Optional[int], timeout: Optional[float], + group_id: int, ) -> EvaluateResultsAndFailures: """Evaluate parameters concurrently on all selected clients.""" with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: submitted_fs = { - executor.submit(evaluate_client, client_proxy, ins, timeout) + executor.submit(evaluate_client, client_proxy, ins, timeout, group_id) for client_proxy, ins in client_instructions } finished_fs, _ = concurrent.futures.wait( @@ -412,9 +423,10 @@ def evaluate_client( client: ClientProxy, ins: EvaluateIns, timeout: Optional[float], + group_id: int, ) -> Tuple[ClientProxy, EvaluateRes]: """Evaluate parameters on a single client.""" - evaluate_res = client.evaluate(ins, timeout=timeout) + evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id) return client, evaluate_res @@ -441,3 +453,44 @@ def _handle_finished_future_after_evaluate( # Not successful, client returned a result where the status code is not OK failures.append(result) + + +def init_defaults( + server: Optional[Server], + config: Optional[ServerConfig], + strategy: Optional[Strategy], + client_manager: Optional[ClientManager], +) -> Tuple[Server, ServerConfig]: + """Create server instance if none was given.""" + if server is None: + if client_manager is None: + client_manager = SimpleClientManager() + if strategy is None: + strategy = FedAvg() + server = Server(client_manager=client_manager, strategy=strategy) + elif strategy is not None: + log(WARN, "Both server and strategy were provided, ignoring strategy") + + # Set default config values + if config is None: + config = ServerConfig() + + return server, config + + +def run_fl( + server: Server, + config: ServerConfig, +) -> History: + """Train a model on the given server and return the History object.""" + hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout) + log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed)) + log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit)) + log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) + log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized)) + log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) + + # Graceful shutdown + server.disconnect_all_clients(timeout=config.round_timeout) + + return hist diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py new file mode 100644 index 000000000000..7b5630b1bad2 --- /dev/null +++ b/src/py/flwr/server/server_app.py @@ -0,0 +1,179 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
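The two helpers now hosted in `server.py` compose as in this sketch; the round count is illustrative:

    from flwr.server.client_manager import SimpleClientManager
    from flwr.server.server import init_defaults, run_fl
    from flwr.server.server_config import ServerConfig
    from flwr.server.strategy import FedAvg

    server, config = init_defaults(
        server=None,  # None -> a Server is assembled from the parts below
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
        client_manager=SimpleClientManager(),
    )
    hist = run_fl(server, config)  # runs num_rounds rounds, then disconnects clients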
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower ServerApp."""
+
+
+import importlib
+from typing import Callable, Optional, cast
+
+from flwr.common import Context, RecordSet
+from flwr.server.strategy import Strategy
+
+from .client_manager import ClientManager
+from .compat import start_driver
+from .driver import Driver
+from .server import Server
+from .server_config import ServerConfig
+from .typing import ServerAppCallable
+
+
+class ServerApp:
+    """Flower ServerApp.
+
+    Examples
+    --------
+    Use the `ServerApp` with an existing `Strategy`:
+
+    >>> server_config = ServerConfig(num_rounds=3)
+    >>> strategy = FedAvg()
+    >>>
+    >>> app = ServerApp(
+    >>>     config=server_config,
+    >>>     strategy=strategy,
+    >>> )
+
+    Use the `ServerApp` with a custom main function:
+
+    >>> app = ServerApp()
+    >>>
+    >>> @app.main()
+    >>> def main(driver: Driver, context: Context) -> None:
+    >>>     print("ServerApp running")
+    """
+
+    def __init__(
+        self,
+        server: Optional[Server] = None,
+        config: Optional[ServerConfig] = None,
+        strategy: Optional[Strategy] = None,
+        client_manager: Optional[ClientManager] = None,
+    ) -> None:
+        self._server = server
+        self._config = config
+        self._strategy = strategy
+        self._client_manager = client_manager
+        self._main: Optional[ServerAppCallable] = None
+
+    def __call__(self, driver: Driver, context: Context) -> None:
+        """Execute `ServerApp`."""
+        # Compatibility mode
+        if not self._main:
+            start_driver(
+                server=self._server,
+                config=self._config,
+                strategy=self._strategy,
+                client_manager=self._client_manager,
+                driver=driver,
+            )
+            return
+
+        # New execution mode
+        context = Context(state=RecordSet())
+        self._main(driver, context)
+
+    def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
+        """Return a decorator that registers the main fn with the server app.
+
+        Examples
+        --------
+        >>> app = ServerApp()
+        >>>
+        >>> @app.main()
+        >>> def main(driver: Driver, context: Context) -> None:
+        >>>     print("ServerApp running")
+        """
+
+        def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable:
+            """Register the main fn with the ServerApp object."""
+            if self._server or self._config or self._strategy or self._client_manager:
+                raise ValueError(
+                    """Use either a custom main function or a `Strategy`, but not both.
+
+                    Use the `ServerApp` with an existing `Strategy`:
+
+                    >>> server_config = ServerConfig(num_rounds=3)
+                    >>> strategy = FedAvg()
+                    >>>
+                    >>> app = ServerApp(
+                    >>>     config=server_config,
+                    >>>     strategy=strategy,
+                    >>> )
+
+                    Use the `ServerApp` with a custom main function:
+
+                    >>> app = ServerApp()
+                    >>>
+                    >>> @app.main()
+                    >>> def main(driver: Driver, context: Context) -> None:
+                    >>>     print("ServerApp running")
+                    """,
+                )
+
+            # Register provided function with the ServerApp object
+            self._main = main_fn
+
+            # Return provided function unmodified
+            return main_fn
+
+        return main_decorator
+
+
+class LoadServerAppError(Exception):
+    """Error when trying to load `ServerApp`."""
+
+
+def load_server_app(module_attribute_str: str) -> ServerApp:
+    """Load the `ServerApp` object specified in a module attribute string.
+
+    The module/attribute string should have the form <module>:<attribute>. Valid
+    examples include `server:app` and `project.package.module:wrapper.app`. 
+    It must refer to a module on the PYTHONPATH, the module needs to have the
+    specified attribute, and the attribute must be of type `ServerApp`.
+    """
+    module_str, _, attributes_str = module_attribute_str.partition(":")
+    if not module_str:
+        raise LoadServerAppError(
+            f"Missing module in {module_attribute_str}",
+        ) from None
+    if not attributes_str:
+        raise LoadServerAppError(
+            f"Missing attribute in {module_attribute_str}",
+        ) from None
+
+    # Load module
+    try:
+        module = importlib.import_module(module_str)
+    except ModuleNotFoundError:
+        raise LoadServerAppError(
+            f"Unable to load module {module_str}",
+        ) from None
+
+    # Recursively load attribute
+    attribute = module
+    try:
+        for attribute_str in attributes_str.split("."):
+            attribute = getattr(attribute, attribute_str)
+    except AttributeError:
+        raise LoadServerAppError(
+            f"Unable to load attribute {attributes_str} from module {module_str}",
+        ) from None
+
+    # Check type
+    if not isinstance(attribute, ServerApp):
+        raise LoadServerAppError(
+            f"Attribute {attributes_str} is not of type {ServerApp}",
+        ) from None
+
+    return cast(ServerApp, attribute)
diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py
new file mode 100644
index 000000000000..38c0d6240d90
--- /dev/null
+++ b/src/py/flwr/server/server_app_test.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ServerApp."""
+
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from flwr.common import Context, RecordSet
+from flwr.server import ServerApp, ServerConfig
+from flwr.server.driver import Driver
+
+
+def test_server_app_custom_mode() -> None:
+    """Test `ServerApp` execution with a custom main function."""
+    # Prepare
+    app = ServerApp()
+    driver = MagicMock()
+    context = Context(state=RecordSet())
+
+    called = {"called": False}
+
+    # pylint: disable=unused-argument
+    @app.main()
+    def custom_main(driver: Driver, context: Context) -> None:
+        called["called"] = True
+
+    # pylint: enable=unused-argument
+
+    # Execute
+    app(driver, context)
+
+    # Assert
+    assert called["called"]
+
+
+def test_server_app_exception_when_both_modes() -> None:
+    """Test ServerApp error when both compat mode and custom fns are used."""
+    # Prepare
+    app = ServerApp(config=ServerConfig(num_rounds=3))
+
+    # Execute and assert
+    with pytest.raises(ValueError):
+        # pylint: disable=unused-argument
+        @app.main()
+        def custom_main(driver: Driver, context: Context) -> None:
+            pass
+
+        # pylint: enable=unused-argument
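A quick sketch of what `load_server_app` does at runtime; the `myproject.server:app` target below is hypothetical:

    from flwr.server.server_app import LoadServerAppError, load_server_app

    try:
        # Roughly equivalent to `from myproject.server import app`,
        # plus the isinstance(ServerApp) check
        app = load_server_app("myproject.server:app")
    except LoadServerAppError as err:
        print(f"Could not load ServerApp: {err}")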
diff --git a/src/py/flwr/server/server_config.py b/src/py/flwr/server/server_config.py
new file mode 100644
index 000000000000..823f832da6f8
--- /dev/null
+++ b/src/py/flwr/server/server_config.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower ServerConfig."""
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class ServerConfig:
+    """Flower server config.
+
+    All attributes have default values, which allows users to configure just the
+    ones they care about.
+    """
+
+    num_rounds: int = 1
+    round_timeout: Optional[float] = None
diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py
index 9b5c03aeeaf9..274e5289fee1 100644
--- a/src/py/flwr/server/server_test.py
+++ b/src/py/flwr/server/server_test.py
@@ -45,18 +45,20 @@ class SuccessClient(ClientProxy):
     """Test class."""
 
     def get_properties(
-        self, ins: GetPropertiesIns, timeout: Optional[float]
+        self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetPropertiesRes:
         """Raise an error because this method is not expected to be called."""
         raise NotImplementedError()
 
     def get_parameters(
-        self, ins: GetParametersIns, timeout: Optional[float]
+        self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[int]
     ) -> GetParametersRes:
         """Raise an error because this method is not expected to be called."""
         raise NotImplementedError()
 
-    def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
+    def fit(
+        self,
ins: FitIns, timeout: Optional[float], group_id: Optional[int] + ) -> FitRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() - def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: + def evaluate( + self, ins: EvaluateIns, timeout: Optional[float], group_id: Optional[int] + ) -> EvaluateRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() - def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: + def reconnect( + self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[int] + ) -> DisconnectRes: """Raise a NotImplementedError to simulate failure in the client.""" raise NotImplementedError() @@ -122,7 +134,7 @@ def test_fit_clients() -> None: client_instructions = [(c, ins) for c in clients] # Execute - results, failures = fit_clients(client_instructions, None, None) + results, failures = fit_clients(client_instructions, None, None, 0) # Assert assert len(results) == 1 @@ -150,6 +162,7 @@ def test_eval_clients() -> None: client_instructions=client_instructions, max_workers=None, timeout=None, + group_id=0, ) # Assert diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 1750a7522379..ffb55c12a60d 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -16,6 +16,11 @@ from .bulyan import Bulyan as Bulyan +from .dp_adaptive_clipping import DifferentialPrivacyClientSideAdaptiveClipping +from .dp_fixed_clipping import ( + DifferentialPrivacyClientSideFixedClipping, + DifferentialPrivacyServerSideFixedClipping, +) from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg @@ -37,24 +42,27 @@ from .strategy import Strategy as Strategy __all__ = [ - "FaultTolerantFedAvg", + "Bulyan", + "DPFedAvgAdaptive", + "DPFedAvgFixed", + "DifferentialPrivacyClientSideAdaptiveClipping", + "DifferentialPrivacyClientSideFixedClipping", + "DifferentialPrivacyServerSideFixedClipping", "FedAdagrad", "FedAdam", "FedAvg", - "FedXgbNnAvg", - "FedXgbBagging", - "FedXgbCyclic", "FedAvgAndroid", "FedAvgM", + "FedMedian", "FedOpt", "FedProx", - "FedYogi", - "QFedAvg", - "FedMedian", "FedTrimmedAvg", + "FedXgbBagging", + "FedXgbCyclic", + "FedXgbNnAvg", + "FedYogi", + "FaultTolerantFedAvg", "Krum", - "Bulyan", - "DPFedAvgAdaptive", - "DPFedAvgFixed", + "QFedAvg", "Strategy", ] diff --git a/src/py/flwr/server/strategy/dp_adaptive_clipping.py b/src/py/flwr/server/strategy/dp_adaptive_clipping.py new file mode 100644 index 000000000000..a5de29e74645 --- /dev/null +++ b/src/py/flwr/server/strategy/dp_adaptive_clipping.py @@ -0,0 +1,245 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
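With the expanded `__all__` above, the new DP wrappers are importable straight from `flwr.server.strategy`; a minimal sketch (assuming a build of this branch is installed):

    from flwr.server.strategy import (
        DifferentialPrivacyClientSideAdaptiveClipping,
        DifferentialPrivacyClientSideFixedClipping,
        DifferentialPrivacyServerSideFixedClipping,
        FedAvg,
    )

    # Each wrapper decorates an inner strategy instead of replacing it
    dp_strategy = DifferentialPrivacyServerSideFixedClipping(
        strategy=FedAvg(),
        noise_multiplier=1.0,
        clipping_norm=10.0,
        num_sampled_clients=10,
    )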
+# ==============================================================================
+"""Central differential privacy with adaptive clipping.
+
+Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
+"""
+
+
+import math
+from logging import WARNING
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
+from flwr.common.differential_privacy import (
+    add_gaussian_noise_to_params,
+    compute_adaptive_noise_params,
+)
+from flwr.common.differential_privacy_constants import (
+    CLIENTS_DISCREPANCY_WARNING,
+    KEY_CLIPPING_NORM,
+    KEY_NORM_BIT,
+)
+from flwr.common.logger import log
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy.strategy import Strategy
+
+
+class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
+    """Strategy wrapper for central DP with client-side adaptive clipping.
+
+    Use the `adaptiveclipping_mod` modifier on the client side.
+
+    In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`,
+    which performs clipping on the server-side,
+    `DifferentialPrivacyClientSideAdaptiveClipping` expects clipping to happen
+    on the client-side, usually by using the built-in `adaptiveclipping_mod`.
+
+    Parameters
+    ----------
+    strategy : Strategy
+        The strategy to which DP functionalities will be added by this wrapper.
+    noise_multiplier : float
+        The noise multiplier for the Gaussian mechanism for model updates.
+    num_sampled_clients : int
+        The number of clients that are sampled on each round.
+    initial_clipping_norm : float
+        The initial value of the clipping norm. Defaults to 0.1, the value
+        recommended by Andrew et al.
+    target_clipped_quantile : float
+        The desired quantile of updates which should be clipped. Defaults to 0.5.
+    clip_norm_lr : float
+        The learning rate for the clipping norm adaptation. Defaults to 0.2, the
+        value recommended by Andrew et al.
+    clipped_count_stddev : float
+        The stddev of the noise added to the count of updates currently below the
+        estimate. Andrew et al. recommend `expected_num_records/20`.
+
+    Examples
+    --------
+    Create a strategy:
+
+    >>> strategy = fl.server.strategy.FedAvg(...)
+
+    Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper:
+
+    >>> dp_strategy = DifferentialPrivacyClientSideAdaptiveClipping(
+    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients
+    >>> )
+
+    On the client, add the `adaptiveclipping_mod` to the client-side mods:
+
+    >>> app = fl.client.ClientApp(
+    >>>     client_fn=client_fn, mods=[adaptiveclipping_mod]
+    >>> )
+    """
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes
+    def __init__(
+        self,
+        strategy: Strategy,
+        noise_multiplier: float,
+        num_sampled_clients: int,
+        initial_clipping_norm: float = 0.1,
+        target_clipped_quantile: float = 0.5,
+        clip_norm_lr: float = 0.2,
+        clipped_count_stddev: Optional[float] = None,
+    ) -> None:
+        super().__init__()
+
+        if strategy is None:
+            raise ValueError("The passed strategy is None.")
+
+        if noise_multiplier < 0:
+            raise ValueError("The noise multiplier should be a non-negative value.")
+
+        if num_sampled_clients <= 0:
+            raise ValueError(
+                "The number of sampled clients should be a positive value."
+ ) + + if initial_clipping_norm <= 0: + raise ValueError("The initial clipping norm should be a positive value.") + + if not 0 <= target_clipped_quantile <= 1: + raise ValueError( + "The target clipped quantile must be between 0 and 1 (inclusive)." + ) + + if clip_norm_lr <= 0: + raise ValueError("The learning rate must be positive.") + + if clipped_count_stddev is not None and clipped_count_stddev < 0: + raise ValueError("The `clipped_count_stddev` must be non-negative.") + + self.strategy = strategy + self.num_sampled_clients = num_sampled_clients + self.clipping_norm = initial_clipping_norm + self.target_clipped_quantile = target_clipped_quantile + self.clip_norm_lr = clip_norm_lr + ( + self.clipped_count_stddev, + self.noise_multiplier, + ) = compute_adaptive_noise_params( + noise_multiplier, + num_sampled_clients, + clipped_count_stddev, + ) + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate training results and update clip norms.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + self._update_clip_norm(results) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + + return aggregated_params, metrics + + def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + # Calculate the number of clients which set the norm indicator bit + norm_bit_set_count = 0 + for client_proxy, fit_res in results: + if KEY_NORM_BIT not in fit_res.metrics: + raise KeyError( + f"{KEY_NORM_BIT} not returned by client with id {client_proxy.cid}." 
+ ) + if fit_res.metrics[KEY_NORM_BIT]: + norm_bit_set_count += 1 + # Add noise to the count + noised_norm_bit_set_count = float( + np.random.normal(norm_bit_set_count, self.clipped_count_stddev) + ) + + noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results) + # Geometric update + self.clipping_norm *= math.exp( + -self.clip_norm_lr + * (noised_norm_bit_set_fraction - self.target_clipped_quantile) + ) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dp_fixed_clipping.py b/src/py/flwr/server/strategy/dp_fixed_clipping.py new file mode 100644 index 000000000000..69930ce49c0b --- /dev/null +++ b/src/py/flwr/server/strategy/dp_fixed_clipping.py @@ -0,0 +1,339 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Central differential privacy with fixed clipping. + +Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963 +""" + + +from logging import WARNING +from typing import Dict, List, Optional, Tuple, Union + +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.differential_privacy import ( + add_gaussian_noise_to_params, + compute_clip_model_update, +) +from flwr.common.differential_privacy_constants import ( + CLIENTS_DISCREPANCY_WARNING, + KEY_CLIPPING_NORM, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.strategy import Strategy + + +class DifferentialPrivacyServerSideFixedClipping(Strategy): + """Strategy wrapper for central DP with server-side fixed clipping. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg( ... 
) + + Wrap the strategy with the DifferentialPrivacyServerSideFixedClipping wrapper + + >>> dp_strategy = DifferentialPrivacyServerSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping norm should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + self.current_round_params: NDArrays = [] + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + self.current_round_params = parameters_to_ndarrays(parameters) + return self.strategy.configure_fit(server_round, parameters, client_manager) + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Compute the updates, clip, and pass them for aggregation. + + Afterward, add noise to the aggregated parameters. 
+ """ + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + for _, res in results: + param = parameters_to_ndarrays(res.parameters) + # Compute and clip update + compute_clip_model_update( + param, self.current_round_params, self.clipping_norm + ) + # Convert back to parameters + res.parameters = ndarrays_to_parameters(param) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) + + +class DifferentialPrivacyClientSideFixedClipping(Strategy): + """Strategy wrapper for central DP with client-side fixed clipping. + + Use `fixedclipping_mod` modifier at the client side. + + In comparison to `DifferentialPrivacyServerSideFixedClipping`, + which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping` + expects clipping to happen on the client-side, usually by using the built-in + `fixedclipping_mod`. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg(...) + + Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper: + + >>> DifferentialPrivacyClientSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + + On the client, add the `fixedclipping_mod` to the client-side mods: + + >>> app = fl.client.ClientApp( + >>> client_fn=client_fn, mods=[fixedclipping_mod] + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping threshold should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." 
+ ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Add noise to the aggregated parameters.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index b182ac26cef8..c54379fc7087 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -104,9 +104,9 @@ def configure_fit( """ additional_config = {"dpfedavg_clip_norm": self.clip_norm} if not self.server_side_noising: - additional_config[ - "dpfedavg_noise_stddev" - ] = self._calc_client_noise_stddev() + additional_config["dpfedavg_noise_stddev"] = ( + self._calc_client_noise_stddev() + ) client_instructions = self.strategy.configure_fit( server_round, parameters, client_manager diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py 
b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py index 026e8dfe51ef..ac62ad014950 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py @@ -46,6 +46,7 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetPropertiesRes: """Request client's set of internal properties.""" get_properties_msg = serde.get_properties_ins_to_proto(ins) @@ -65,6 +66,7 @@ def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.GetParametersRes: """Return the current local model parameters.""" get_parameters_msg = serde.get_parameters_ins_to_proto(ins) @@ -84,6 +86,7 @@ def fit( self, ins: common.FitIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.FitRes: """Refine the provided parameters using the locally held dataset.""" fit_ins_msg = serde.fit_ins_to_proto(ins) @@ -102,6 +105,7 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.EvaluateRes: """Evaluate the provided parameters using the locally held dataset.""" evaluate_msg = serde.evaluate_ins_to_proto(ins) @@ -119,6 +123,7 @@ def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float], + group_id: Optional[int], ) -> common.DisconnectRes: """Disconnect and (optionally) reconnect later.""" reconnect_ins_msg = serde.reconnect_ins_to_proto(ins) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py index 360570eb663d..e7077dfd39ae 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py @@ -71,7 +71,7 @@ def test_get_parameters(self) -> None: # Execute value: flwr.common.GetParametersRes = client.get_parameters( - ins=get_parameters_ins, timeout=None + ins=get_parameters_ins, timeout=None, group_id=0 ) # Assert @@ -88,7 +88,7 @@ def test_fit(self) -> None: ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) # Execute - fit_res = client.fit(ins=ins, timeout=None) + fit_res = client.fit(ins=ins, timeout=None, group_id=0) # Assert assert fit_res.parameters.tensor_type == "np" @@ -106,7 +106,7 @@ def test_evaluate(self) -> None: evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) # Execute - evaluate_res = client.evaluate(evaluate_ins, timeout=None) + evaluate_res = client.evaluate(evaluate_ins, timeout=None, group_id=1) # Assert assert (0, 0.0) == ( @@ -127,7 +127,9 @@ def test_get_properties(self) -> None: ) # Execute - value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None) + value: flwr.common.GetPropertiesRes = client.get_properties( + ins, timeout=None, group_id=0 + ) # Assert assert value.properties["tensor_type"] == "numpy.ndarray" diff --git a/src/py/flwr/server/superlink/fleet/vce/__init__.py b/src/py/flwr/server/superlink/fleet/vce/__init__.py new file mode 100644 index 000000000000..72cd76f73761 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
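All five `ClientProxy` methods now take a trailing `group_id` (the id of the message group, in practice the current server round). Third-party proxies need the widened signatures too; a hedged sketch of the new contract, showing only `fit` (the remaining abstract methods follow the same pattern and are elided):

    from typing import Optional

    from flwr.common import FitIns, FitRes
    from flwr.server.client_proxy import ClientProxy

    class MyClientProxy(ClientProxy):
        """Hypothetical proxy illustrating the widened `fit` signature."""

        def fit(
            self,
            ins: FitIns,
            timeout: Optional[float],
            group_id: Optional[int],  # New in this PR
        ) -> FitRes:
            raise NotImplementedError()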
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Fleet VirtualClientEngine side."""
+
+from .vce_api import start_vce
+
+__all__ = [
+    "start_vce",
+]
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py
new file mode 100644
index 000000000000..d751cf4bcae1
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py
@@ -0,0 +1,48 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simulation Engine Backends."""
+
+import importlib.util
+from typing import Dict, Type
+
+from .backend import Backend, BackendConfig
+
+is_ray_installed = importlib.util.find_spec("ray") is not None
+
+# Mapping of supported backends
+supported_backends: Dict[str, Type[Backend]] = {}
+
+# To log backend-specific error message when chosen backend isn't available
+error_messages_backends: Dict[str, str] = {}
+
+if is_ray_installed:
+    from .raybackend import RayBackend
+
+    supported_backends["ray"] = RayBackend
+else:
+    error_messages_backends[
+        "ray"
+    ] = """Unable to import module `ray`.
+
+    To install the necessary dependencies, install `flwr` with the `simulation` extra:
+
+    pip install -U flwr["simulation"]
+    """
+
+
+__all__ = [
+    "Backend",
+    "BackendConfig",
+]
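The two dicts above form a small registry that the Fleet API consults by name; a sketch of that lookup (mirroring how `vce_api.py`, added later in this diff, uses it):

    from flwr.server.superlink.fleet.vce.backend import (
        error_messages_backends,
        supported_backends,
    )

    backend_name = "ray"
    try:
        backend_cls = supported_backends[backend_name]
    except KeyError:
        # Surface the backend-specific install hint when one is registered
        hint = error_messages_backends.get(backend_name, "")
        raise RuntimeError(
            f"Backend `{backend_name}` is not supported. {hint}"
        ) from None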
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py
new file mode 100644
index 000000000000..1d5e3a6a51ad
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py
@@ -0,0 +1,67 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generic Backend class for Fleet API using the Simulation Engine."""
+
+
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Tuple
+
+from flwr.client.client_app import ClientApp
+from flwr.common.context import Context
+from flwr.common.message import Message
+from flwr.common.typing import ConfigsRecordValues
+
+BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]]
+
+
+class Backend(ABC):
+    """Abstract base class for a Simulation Engine Backend."""
+
+    def __init__(self, backend_config: BackendConfig, work_dir: str) -> None:
+        """Construct a backend."""
+
+    @abstractmethod
+    async def build(self) -> None:
+        """Build backend asynchronously.
+
+        Different components need to be in place before workers in a backend are
+        ready to accept jobs. When this method finishes executing, the backend
+        should be fully ready to run jobs.
+        """
+
+    @property
+    def num_workers(self) -> int:
+        """Return number of workers in the backend.
+
+        This is the number of TaskIns that can be processed concurrently.
+        """
+        return 0
+
+    @abstractmethod
+    def is_worker_idle(self) -> bool:
+        """Report whether a backend worker is idle and can therefore run a ClientApp."""
+
+    @abstractmethod
+    async def terminate(self) -> None:
+        """Terminate backend."""
+
+    @abstractmethod
+    async def process_message(
+        self,
+        app: Callable[[], ClientApp],
+        message: Message,
+        context: Context,
+    ) -> Tuple[Message, Context]:
+        """Submit a job to the backend."""
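To make the ABC contract concrete, here is a minimal hypothetical in-process backend that satisfies the interface by running the `ClientApp` in the caller's process (it assumes `ClientApp` is callable with `(message, context)`, as the Ray actor path does):

    from typing import Callable, Tuple

    from flwr.client.client_app import ClientApp
    from flwr.common.context import Context
    from flwr.common.message import Message
    from flwr.server.superlink.fleet.vce.backend.backend import Backend, BackendConfig

    class InProcessBackend(Backend):
        """Toy backend: a single 'worker' running in the caller's process."""

        def __init__(self, backend_config: BackendConfig, work_dir: str) -> None:
            super().__init__(backend_config, work_dir)

        async def build(self) -> None:
            """Nothing to set up."""

        @property
        def num_workers(self) -> int:
            return 1

        def is_worker_idle(self) -> bool:
            return True

        async def terminate(self) -> None:
            """Nothing to tear down."""

        async def process_message(
            self,
            app: Callable[[], ClientApp],
            message: Message,
            context: Context,
        ) -> Tuple[Message, Context]:
            # Instantiate the ClientApp and let it handle the message directly
            out_message = app()(message=message, context=context)
            return out_message, context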
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
new file mode 100644
index 000000000000..409deb077f1d
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
@@ -0,0 +1,176 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ray backend for the Fleet API using the Simulation Engine."""
+
+import pathlib
+from logging import ERROR, INFO
+from typing import Callable, Dict, List, Tuple, Union
+
+import ray
+
+from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.common.context import Context
+from flwr.common.logger import log
+from flwr.common.message import Message
+from flwr.simulation.ray_transport.ray_actor import (
+    BasicActorPool,
+    ClientAppActor,
+    init_ray,
+)
+from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
+
+from .backend import Backend, BackendConfig
+
+ClientResourcesDict = Dict[str, Union[int, float]]
+
+
+class RayBackend(Backend):
+    """A backend that submits jobs to a `BasicActorPool`."""
+
+    def __init__(
+        self,
+        backend_config: BackendConfig,
+        work_dir: str,
+    ) -> None:
+        """Prepare RayBackend by initialising Ray and creating the ActorPool."""
+        log(INFO, "Initialising: %s", self.__class__.__name__)
+        log(INFO, "Backend config: %s", backend_config)
+
+        if not pathlib.Path(work_dir).exists():
+            raise ValueError(f"Specified work_dir {work_dir} does not exist.")
+
+        # Init ray and append working dir if needed
+        runtime_env = (
+            self._configure_runtime_env(work_dir=work_dir) if work_dir else None
+        )
+        init_ray(runtime_env=runtime_env)
+
+        # Validate client resources
+        self.client_resources_key = "client_resources"
+
+        # Create actor pool
+        use_tf = backend_config.get("tensorflow", False)
+        actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}
+
+        client_resources = self._validate_client_resources(config=backend_config)
+        self.pool = BasicActorPool(
+            actor_type=ClientAppActor,
+            client_resources=client_resources,
+            actor_kwargs=actor_kwargs,
+        )
+
+    def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
+        """Return a Ray runtime env whose `excludes` lists all non-Python files in work_dir.
+
+        Without the excludes, Ray would push everything in work_dir to the Ray
+        cluster.
+        """
+        runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}
+
+        excludes = []
+        path = pathlib.Path(work_dir)
+        for p in path.rglob("*"):
+            # Excluded paths need to be relative to the working_dir
+            if p.is_file() and not str(p).endswith(".py"):
+                excludes.append(str(p.relative_to(path)))
+        runtime_env["excludes"] = excludes
+
+        return runtime_env
+
+    def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
+        client_resources_config = config.get(self.client_resources_key)
+        client_resources: ClientResourcesDict = {}
+        valid_types = (int, float)
+        if client_resources_config:
+            for k, v in client_resources_config.items():
+                if not isinstance(k, str):
+                    raise ValueError(
+                        f"client resources keys are expected to be `str` but you used "
+                        f"{type(k)} for `{k}`"
+                    )
+                if not isinstance(v, valid_types):
+                    raise ValueError(
+                        f"client resources are expected to be of type {valid_types} "
+                        f"but found `{type(v)}` for key `{k}`",
+                    )
+                client_resources[k] = v
+
+        else:
+            client_resources = {"num_cpus": 2, "num_gpus": 0.0}
+            log(
+                INFO,
+                "`%s` not specified in backend config. Applying default setting: %s",
+                self.client_resources_key,
+                client_resources,
+            )
+
+        return client_resources
+
+    @property
+    def num_workers(self) -> int:
+        """Return number of actors in pool."""
+        return self.pool.num_actors
+
+    def is_worker_idle(self) -> bool:
+        """Report whether the pool has idle actors."""
+        return self.pool.is_actor_available()
+
+    async def build(self) -> None:
+        """Build pool of Ray actors that this backend will submit jobs to."""
+        await self.pool.add_actors_to_pool(self.pool.actors_capacity)
+        log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)
+
+    async def process_message(
+        self,
+        app: Callable[[], ClientApp],
+        message: Message,
+        context: Context,
+    ) -> Tuple[Message, Context]:
+        """Run the ClientApp that processes a given message.
+
+        Return the output message and the updated context.
+        """
+        node_id = message.metadata.dst_node_id
+
+        try:
+            # Submit a task to the pool
+            future = await self.pool.submit(
+                lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
+                (app, message, str(node_id), context),
+            )
+
+            await future
+
+            # Fetch result
+            (
+                out_mssg,
+                updated_context,
+            ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
+
+            return out_mssg, updated_context
+
+        except LoadClientAppError as load_ex:
+            log(
+                ERROR,
+                "An exception was raised when processing a message. Terminating %s",
+                self.__class__.__name__,
+            )
+            await self.terminate()
+            raise load_ex
+
+    async def terminate(self) -> None:
+        """Terminate all actors in actor pool."""
+        await self.pool.terminate_all_actors()
+        ray.shutdown()
+        log(INFO, "Terminated %s", self.__class__.__name__)
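For reference, a backend config matching the keys `RayBackend` reads above; the values are illustrative, not required:

    backend_config = {
        # Resources reserved per ClientApp actor
        "client_resources": {"num_cpus": 2, "num_gpus": 0.0},
        # Optionally: "tensorflow": True, to enable GPU memory growth in actors
    }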
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py
new file mode 100644
index 000000000000..fd246b5fc2af
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py
@@ -0,0 +1,200 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for Ray backend for the Fleet API using the Simulation Engine."""
+
+import asyncio
+from math import pi
+from pathlib import Path
+from typing import Callable, Dict, Optional, Tuple, Union
+from unittest import IsolatedAsyncioTestCase
+
+from flwr.client import Client, NumPyClient
+from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
+from flwr.common import (
+    Config,
+    ConfigsRecord,
+    Context,
+    GetPropertiesIns,
+    Message,
+    Metadata,
+    RecordSet,
+    Scalar,
+)
+from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
+from flwr.common.recordset_compat import getpropertiesins_to_recordset
+from flwr.server.superlink.fleet.vce.backend.raybackend import RayBackend
+
+
+class DummyClient(NumPyClient):
+    """A dummy NumPyClient for tests."""
+
+    def get_properties(self, config: Config) -> Dict[str, Scalar]:
+        """Return properties by doing a simple calculation."""
+        result = float(config["factor"]) * pi
+
+        # Store something in context
+        self.context.state.configs_records["result"] = ConfigsRecord({"result": result})
+        return {"result": result}
+
+
+def get_dummy_client(cid: str) -> Client:  # pylint: disable=unused-argument
+    """Return a DummyClient converted to Client type."""
+    return DummyClient().to_client()
+
+
+def _load_app() -> ClientApp:
+    return ClientApp(client_fn=get_dummy_client)
+
+
+client_app = ClientApp(
+    client_fn=get_dummy_client,
+)
+
+
+def _load_from_module(client_app_module_name: str) -> Callable[[], ClientApp]:
+    def _load_app() -> ClientApp:
+        app: ClientApp = load_client_app(client_app_module_name)
+        return app
+
+    return _load_app
+
+
+async def backend_build_process_and_termination(
+    backend: RayBackend,
+    process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None,
+) -> Union[Tuple[Message, Context], None]:
+    """Build, process job and terminate RayBackend."""
+    await backend.build()
+    to_return = None
+
+    if process_args:
+        to_return = await backend.process_message(*process_args)
+
+    await backend.terminate()
+
+    return to_return
+
+
+def _create_message_and_context() -> Tuple[Message, Context, float]:
+
+    # Construct a Message
+    mult_factor = 2024
+    getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
+    recordset = getpropertiesins_to_recordset(getproperties_ins)
+    message = Message(
+        content=recordset,
+        metadata=Metadata(
+            run_id=0,
+            message_id="",
+            group_id="",
+            src_node_id=0,
+            dst_node_id=0,
+            reply_to_message="",
+            ttl="",
+            message_type=MESSAGE_TYPE_GET_PROPERTIES,
+        ),
+    )
+
+    # Construct an empty Context
+    context = Context(state=RecordSet())
+
+    # Expected output
+    expected_output = pi * mult_factor
+
+    return message, context, expected_output
+
+
+class AsyncTestRayBackend(IsolatedAsyncioTestCase):
+    """A basic class that allows running multiple asyncio tests."""
+
+    def test_backend_creation_and_termination(self) -> None:
+        """Test creation of RayBackend and its termination."""
+        backend = RayBackend(backend_config={}, work_dir="")
+        asyncio.run(
+            backend_build_process_and_termination(backend=backend, process_args=None)
+        )
+
+    def test_backend_creation_submit_and_termination(
+        self,
+        client_app_loader: Callable[[], ClientApp] = _load_app,
+        workdir: str = "",
+    ) -> None:
+        """Test submitting a message to a given ClientApp."""
+        backend = RayBackend(backend_config={}, work_dir=workdir)
+
+        # Define ClientApp
+        client_app_callable = client_app_loader
+
+        message, context, expected_output = _create_message_and_context()
+
+        res = asyncio.run(
+            backend_build_process_and_termination(
+                backend=backend, process_args=(client_app_callable, message, context)
+            )
+        )
+
+        if res is None:
+            raise AssertionError("This shouldn't happen")
+
+        out_mssg, updated_context = res
+
+        # Verify message content is as expected
+        content = out_mssg.content
+        assert (
+            content.configs_records["getpropertiesres.properties"]["result"]
+            == expected_output
+        )
+
+        # Verify context is correct
+        obtained_result_in_context = updated_context.state.configs_records["result"][
+            "result"
+        ]
+        assert obtained_result_in_context == expected_output
+
+    def test_backend_creation_submit_and_termination_non_existing_client_app(
+        self,
+    ) -> None:
+        """Testing with a ClientApp module that does not exist."""
+        with self.assertRaises(LoadClientAppError):
+            self.test_backend_creation_submit_and_termination(
+                client_app_loader=_load_from_module("a_non_existing_module:app")
+            )
+
+    def test_backend_creation_submit_and_termination_existing_client_app(
+        self,
+    ) -> None:
+        """Testing with a ClientApp module that exists."""
+        # Resolve what should be the workdir to pass upon Backend initialisation
+        file_path = Path(__file__)
+        working_dir = Path.cwd()
+        rel_workdir = file_path.relative_to(working_dir)
+
+        # Drop the last element (the file name)
+        rel_workdir_str = str(rel_workdir.parent)
+
+        self.test_backend_creation_submit_and_termination(
+            client_app_loader=_load_from_module("raybackend_test:client_app"),
+            workdir=rel_workdir_str,
+        )
+
+    def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdir(
+        self,
+    ) -> None:
+        """Testing with a ClientApp module that exists but a workdir that does not."""
+        with self.assertRaises(ValueError):
+            self.test_backend_creation_submit_and_termination(
+                client_app_loader=_load_from_module("raybackend_test:client_app"),
+                workdir="/?&%$^#%@$!",
+            )
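The Fleet-side entry point added in the next file, `start_vce`, receives the backend config as a JSON stream; a hedged sketch of wiring it up (the `:flwr-in-memory-state:` sentinel and the module target are assumptions for illustration):

    import json

    from flwr.server.superlink.fleet.vce import start_vce
    from flwr.server.superlink.state import StateFactory

    # Assumed: the in-memory state sentinel used elsewhere in SuperLink
    state_factory = StateFactory(":flwr-in-memory-state:")

    start_vce(
        num_supernodes=10,
        client_app_module_name="myproject.client:app",  # hypothetical target
        backend_name="ray",
        backend_config_json_stream=json.dumps(
            {"client_resources": {"num_cpus": 2, "num_gpus": 0.0}}
        ),
        state_factory=state_factory,
        working_dir=".",
    )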
diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
new file mode 100644
index 000000000000..5d194632541e
--- /dev/null
+++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
@@ -0,0 +1,92 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Fleet VirtualClientEngine API."""
+
+import asyncio
+import json
+from logging import ERROR, INFO
+from typing import Dict, Optional
+
+from flwr.client.client_app import ClientApp, load_client_app
+from flwr.client.node_state import NodeState
+from flwr.common.logger import log
+from flwr.server.superlink.state import StateFactory
+
+from .backend import error_messages_backends, supported_backends
+
+NodeToPartitionMapping = Dict[int, int]
+
+
+def _register_nodes(
+    num_nodes: int, state_factory: StateFactory
+) -> NodeToPartitionMapping:
+    """Register nodes with the StateFactory and create node-id:partition-id mapping."""
+    nodes_mapping: NodeToPartitionMapping = {}
+    state = state_factory.state()
+    for i in range(num_nodes):
+        node_id = state.create_node()
+        nodes_mapping[node_id] = i
+    log(INFO, "Registered %i nodes", len(nodes_mapping))
+    return nodes_mapping
+
+
+# pylint: disable=too-many-arguments,unused-argument
+def start_vce(
+    num_supernodes: int,
+    client_app_module_name: str,
+    backend_name: str,
+    backend_config_json_stream: str,
+    state_factory: StateFactory,
+    working_dir: str,
+    f_stop: Optional[asyncio.Event] = None,
+) -> None:
+    """Start Fleet API with the VirtualClientEngine (VCE)."""
+    # Register SuperNodes
+    nodes_mapping = _register_nodes(
+        num_nodes=num_supernodes, state_factory=state_factory
+    )
+
+    # Construct mapping of NodeStates
+    node_states: Dict[int, NodeState] = {}
+    for node_id in nodes_mapping:
+        node_states[node_id] = NodeState()
+
+    # Load backend config
+    log(INFO, "Supported backends: %s", list(supported_backends.keys()))
+    backend_config = json.loads(backend_config_json_stream)
+
+    try:
+        backend_type = supported_backends[backend_name]
+        _ = backend_type(backend_config, work_dir=working_dir)
+    except KeyError as ex:
+        log(
+            ERROR,
+            "Backend `%s` is not supported. 
Use any of %s or add support " + "for a new backend.", + backend_name, + list(supported_backends.keys()), + ) + if backend_name in error_messages_backends: + log(ERROR, error_messages_backends[backend_name]) + + raise ex + + log(INFO, "client_app_module_name = %s", client_app_module_name) + + def _load() -> ClientApp: + app: ClientApp = load_client_app(client_app_module_name) + return app + + # start backend diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ecb39f18300a..690fadc032d7 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -16,6 +16,7 @@ import os +import threading from datetime import datetime, timedelta from logging import ERROR from typing import Dict, List, Optional, Set @@ -35,6 +36,7 @@ def __init__(self) -> None: self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} + self.lock = threading.Lock() def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: """Store one TaskIns.""" @@ -57,7 +59,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() task_ins.task.ttl = ttl.isoformat() - self.task_ins_store[task_id] = task_ins + with self.lock: + self.task_ins_store[task_id] = task_ins # Return the new task_id return task_id @@ -71,22 +74,23 @@ def get_task_ins( # Find TaskIns for node_id that were not delivered yet task_ins_list: List[TaskIns] = [] - for _, task_ins in self.task_ins_store.items(): - # pylint: disable=too-many-boolean-expressions - if ( - node_id is not None # Not anonymous - and task_ins.task.consumer.anonymous is False - and task_ins.task.consumer.node_id == node_id - and task_ins.task.delivered_at == "" - ) or ( - node_id is None # Anonymous - and task_ins.task.consumer.anonymous is True - and task_ins.task.consumer.node_id == 0 - and task_ins.task.delivered_at == "" - ): - task_ins_list.append(task_ins) - if limit and len(task_ins_list) == limit: - break + with self.lock: + for _, task_ins in self.task_ins_store.items(): + # pylint: disable=too-many-boolean-expressions + if ( + node_id is not None # Not anonymous + and task_ins.task.consumer.anonymous is False + and task_ins.task.consumer.node_id == node_id + and task_ins.task.delivered_at == "" + ) or ( + node_id is None # Anonymous + and task_ins.task.consumer.anonymous is True + and task_ins.task.consumer.node_id == 0 + and task_ins.task.delivered_at == "" + ): + task_ins_list.append(task_ins) + if limit and len(task_ins_list) == limit: + break # Mark all of them as delivered delivered_at = now().isoformat() @@ -164,7 +168,8 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: task_res_to_be_deleted.add(task_res_id) for task_id in task_ins_to_be_deleted: - del self.task_ins_store[task_id] + with self.lock: + del self.task_ins_store[task_id] for task_id in task_res_to_be_deleted: del self.task_res_store[task_id] diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py new file mode 100644 index 000000000000..01143af74392 --- /dev/null +++ b/src/py/flwr/server/typing.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
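The `threading.Lock` added to `InMemoryState` above matters because, with the Simulation Engine, `TaskIns` can be stored and fetched from different threads. A compressed illustration of the guarded pattern (hypothetical, not the real class):

    import threading
    from typing import Dict
    from uuid import UUID, uuid4

    store: Dict[UUID, str] = {}
    lock = threading.Lock()

    def store_item(value: str) -> UUID:
        task_id = uuid4()
        with lock:  # Writers serialize with readers on the same lock
            store[task_id] = value
        return task_id

    def pop_items() -> Dict[UUID, str]:
        with lock:  # Prevents iteration racing a concurrent insert/delete
            items = dict(store)
            store.clear()
        return items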
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Custom types for Flower servers."""
+
+
+from typing import Callable
+
+from flwr.common import Context
+
+from .driver import Driver
+
+ServerAppCallable = Callable[[Driver, Context], None]
+Workflow = Callable[[Driver, Context], None]
diff --git a/src/py/flwr/server/workflow/__init__.py b/src/py/flwr/server/workflow/__init__.py
new file mode 100644
index 000000000000..098b0dbfb92f
--- /dev/null
+++ b/src/py/flwr/server/workflow/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflows."""
+
+
+from .default_workflows import DefaultWorkflow
+
+__all__ = [
+    "DefaultWorkflow",
+]
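`Workflow` has the same shape as `ServerAppCallable`, so any `Callable[[Driver, Context], None]` can act as a fit or evaluate phase. A hedged sketch of plugging a custom fit phase into `DefaultWorkflow` (defined in the file that follows); the body is a stand-in:

    from flwr.common import Context
    from flwr.server.driver import Driver
    from flwr.server.workflow import DefaultWorkflow

    def custom_fit_workflow(driver: Driver, context: Context) -> None:
        # A real workflow would send FitIns messages and aggregate the replies
        print("custom fit phase")

    workflow = DefaultWorkflow(fit_workflow=custom_fit_workflow)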
+# ==============================================================================
+"""Legacy default workflows."""
+
+
+import timeit
+from logging import DEBUG, INFO
+from typing import Optional, cast
+
+import flwr.common.recordset_compat as compat
+from flwr.common import ConfigsRecord, Context, GetParametersIns, log
+from flwr.common.constant import (
+    MESSAGE_TYPE_EVALUATE,
+    MESSAGE_TYPE_FIT,
+    MESSAGE_TYPE_GET_PARAMETERS,
+)
+
+from ..compat.app_utils import start_update_client_manager_thread
+from ..compat.legacy_context import LegacyContext
+from ..driver import Driver
+from ..typing import Workflow
+
+KEY_CURRENT_ROUND = "current_round"
+KEY_START_TIME = "start_time"
+CONFIGS_RECORD_KEY = "config"
+PARAMS_RECORD_KEY = "parameters"
+
+
+class DefaultWorkflow:
+    """Default workflow in Flower."""
+
+    def __init__(
+        self,
+        fit_workflow: Optional[Workflow] = None,
+        evaluate_workflow: Optional[Workflow] = None,
+    ) -> None:
+        if fit_workflow is None:
+            fit_workflow = default_fit_workflow
+        if evaluate_workflow is None:
+            evaluate_workflow = default_evaluate_workflow
+        self.fit_workflow: Workflow = fit_workflow
+        self.evaluate_workflow: Workflow = evaluate_workflow
+
+    def __call__(self, driver: Driver, context: Context) -> None:
+        """Execute the workflow."""
+        if not isinstance(context, LegacyContext):
+            raise TypeError(
+                f"Expected a LegacyContext, but got {type(context).__name__}."
+            )
+
+        # Start the thread updating nodes
+        thread, f_stop = start_update_client_manager_thread(
+            driver, context.client_manager
+        )
+
+        # Initialize parameters
+        default_init_params_workflow(driver, context)
+
+        # Run federated learning for num_rounds
+        log(INFO, "FL starting")
+        start_time = timeit.default_timer()
+        cfg = ConfigsRecord()
+        cfg[KEY_START_TIME] = start_time
+        context.state.configs_records[CONFIGS_RECORD_KEY] = cfg
+
+        for current_round in range(1, context.config.num_rounds + 1):
+            cfg[KEY_CURRENT_ROUND] = current_round
+
+            # Fit round
+            self.fit_workflow(driver, context)
+
+            # Centralized evaluation
+            default_centralized_evaluation_workflow(driver, context)
+
+            # Evaluate round
+            self.evaluate_workflow(driver, context)
+
+        # Bookkeeping
+        end_time = timeit.default_timer()
+        elapsed = end_time - start_time
+        log(INFO, "FL finished in %s", elapsed)
+
+        # Log results
+        hist = context.history
+        log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
+        log(
+            INFO,
+            "app_fit: metrics_distributed_fit %s",
+            str(hist.metrics_distributed_fit),
+        )
+        log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
+        log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
+        log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
+
+        # Terminate the thread
+        f_stop.set()
+        del driver
+        thread.join()
+
+
+def default_init_params_workflow(driver: Driver, context: Context) -> None:
+    """Execute the default workflow for parameters initialization."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    log(INFO, "Initializing global parameters")
+    parameters = context.strategy.initialize_parameters(
+        client_manager=context.client_manager
+    )
+    if parameters is not None:
+        log(INFO, "Using initial parameters provided by strategy")
+        paramsrecord = compat.parameters_to_parametersrecord(
+            parameters, keep_input=True
+        )
+    else:
+        # Get initial parameters from one of the clients
+        log(INFO, "Requesting initial parameters from one random client")
+        random_client = context.client_manager.sample(1)[0]
+        # Send GetParametersIns and get the response
+        content = compat.getparametersins_to_recordset(GetParametersIns({}))
+        messages = driver.send_and_receive(
+            [
+                driver.create_message(
+                    content=content,
+                    message_type=MESSAGE_TYPE_GET_PARAMETERS,
+                    dst_node_id=random_client.node_id,
+                    group_id="",
+                    ttl="",
+                )
+            ]
+        )
+        log(INFO, "Received initial parameters from one random client")
+        msg = list(messages)[0]
+        paramsrecord = next(iter(msg.content.parameters_records.values()))
+
+    context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
+
+    # Evaluate initial parameters
+    log(INFO, "Evaluating initial parameters")
+    parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
+    res = context.strategy.evaluate(0, parameters=parameters)
+    if res is not None:
+        log(
+            INFO,
+            "initial parameters (loss, other metrics): %s, %s",
+            res[0],
+            res[1],
+        )
+        context.history.add_loss_centralized(server_round=0, loss=res[0])
+        context.history.add_metrics_centralized(server_round=0, metrics=res[1])
+
+
+def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
+    """Execute the default workflow for centralized evaluation."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Retrieve current_round and start_time from the context
+    cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
+    current_round = cast(int, cfg[KEY_CURRENT_ROUND])
+    start_time = cast(float, cfg[KEY_START_TIME])
+
+    # Centralized evaluation
+    parameters = compat.parametersrecord_to_parameters(
+        record=context.state.parameters_records[PARAMS_RECORD_KEY],
+        keep_input=True,
+    )
+    res_cen = context.strategy.evaluate(current_round, parameters=parameters)
+    if res_cen is not None:
+        loss_cen, metrics_cen = res_cen
+        log(
+            INFO,
+            "fit progress: (%s, %s, %s, %s)",
+            current_round,
+            loss_cen,
+            metrics_cen,
+            timeit.default_timer() - start_time,
+        )
+        context.history.add_loss_centralized(server_round=current_round, loss=loss_cen)
+        context.history.add_metrics_centralized(
+            server_round=current_round, metrics=metrics_cen
+        )
+
+
+def default_fit_workflow(driver: Driver, context: Context) -> None:
+    """Execute the default workflow for a single fit round."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Get current_round and parameters
+    cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
+    current_round = cast(int, cfg[KEY_CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
+    parameters = compat.parametersrecord_to_parameters(
+        parametersrecord, keep_input=True
+    )
+
+    # Get clients and their respective instructions from strategy
+    client_instructions = context.strategy.configure_fit(
+        server_round=current_round,
+        parameters=parameters,
+        client_manager=context.client_manager,
+    )
+
+    if not client_instructions:
+        log(INFO, "fit_round %s: no clients selected, cancel", current_round)
+        return
+    log(
+        DEBUG,
+        "fit_round %s: strategy sampled %s clients (out of %s)",
+        current_round,
+        len(client_instructions),
+        context.client_manager.num_available(),
+    )
+
+    # Build dictionary mapping node_id to ClientProxy
+    node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
+
+    # Build out messages
+    out_messages = [
+        driver.create_message(
+            content=compat.fitins_to_recordset(fitins, True),
+            message_type=MESSAGE_TYPE_FIT,
+            dst_node_id=proxy.node_id,
+            group_id="",
+            ttl="",
+        )
+        for proxy, fitins in client_instructions
+    ]
+
+    # Send instructions to clients and
+    # collect `fit` results from all clients participating in this round
+    messages = list(driver.send_and_receive(out_messages))
+    del out_messages
+
+    # No exception/failure handling currently
+    log(
+        DEBUG,
+        "fit_round %s received %s results and %s failures",
+        current_round,
+        len(messages),
+        0,
+    )
+
+    # Aggregate training results
+    results = [
+        (
+            node_id_to_proxy[msg.metadata.src_node_id],
+            compat.recordset_to_fitres(msg.content, False),
+        )
+        for msg in messages
+    ]
+    aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
+    parameters_aggregated, metrics_aggregated = aggregated_result
+
+    # Update the parameters and write history
+    if parameters_aggregated:
+        paramsrecord = compat.parameters_to_parametersrecord(
+            parameters_aggregated, True
+        )
+        context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
+        context.history.add_metrics_distributed_fit(
+            server_round=current_round, metrics=metrics_aggregated
+        )
+
+
+def default_evaluate_workflow(driver: Driver, context: Context) -> None:
+    """Execute the default workflow for a single evaluate round."""
+    if not isinstance(context, LegacyContext):
+        raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
+
+    # Get current_round and parameters
+    cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
+    current_round = cast(int, cfg[KEY_CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
+    parameters = compat.parametersrecord_to_parameters(
+        parametersrecord, keep_input=True
+    )
+
+    # Get clients and their respective instructions from strategy
+    client_instructions = context.strategy.configure_evaluate(
+        server_round=current_round,
+        parameters=parameters,
+        client_manager=context.client_manager,
+    )
+    if not client_instructions:
+        log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
+        return
+    log(
+        DEBUG,
+        "evaluate_round %s: strategy sampled %s clients (out of %s)",
+        current_round,
+        len(client_instructions),
+        context.client_manager.num_available(),
+    )
+
+    # Build dictionary mapping node_id to ClientProxy
+    node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
+
+    # Build out messages
+    out_messages = [
+        driver.create_message(
+            content=compat.evaluateins_to_recordset(evalins, True),
+            message_type=MESSAGE_TYPE_EVALUATE,
+            dst_node_id=proxy.node_id,
+            group_id="",
+            ttl="",
+        )
+        for proxy, evalins in client_instructions
+    ]
+
+    # Send instructions to clients and
+    # collect `evaluate` results from all clients participating in this round
+    messages = list(driver.send_and_receive(out_messages))
+    del out_messages
+
+    # No exception/failure handling currently
+    log(
+        DEBUG,
+        "evaluate_round %s received %s results and %s failures",
+        current_round,
+        len(messages),
+        0,
+    )
+
+    # Aggregate the evaluation results
+    results = [
+        (
+            node_id_to_proxy[msg.metadata.src_node_id],
+            compat.recordset_to_evaluateres(msg.content),
+        )
+        for msg in messages
+    ]
+    aggregated_result = context.strategy.aggregate_evaluate(current_round, results, [])
+
+    loss_aggregated, metrics_aggregated = aggregated_result
+
+    # Write history
+    if loss_aggregated is not None:
+        context.history.add_loss_distributed(
+            server_round=current_round, loss=loss_aggregated
+        )
+        context.history.add_metrics_distributed(
+            server_round=current_round, metrics=metrics_aggregated
+        )
diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py
index 6a18a258ac60..ff18f37664be 100644
--- a/src/py/flwr/simulation/app.py
+++ b/src/py/flwr/simulation/app.py
@@ -28,13 +28,13 @@
 from flwr.client import ClientFn
 from flwr.common import EventType, event
 from flwr.common.logger import log
-from flwr.server import Server
-from flwr.server.app import ServerConfig, init_defaults, run_fl
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
+from flwr.server.server import Server, init_defaults, run_fl
+from flwr.server.server_config import ServerConfig
 from flwr.server.strategy import Strategy
 from flwr.simulation.ray_transport.ray_actor import (
-    DefaultActor,
+    ClientAppActor,
     VirtualClientEngineActor,
     VirtualClientEngineActorPool,
     pool_size_from_resources,
@@ -82,7 +82,7 @@ def start_simulation(
     client_manager: Optional[ClientManager] = None,
     ray_init_args: Optional[Dict[str, Any]] = None,
     keep_initialised: Optional[bool] = False,
-    actor_type: Type[VirtualClientEngineActor] = DefaultActor,
+    actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
     actor_kwargs: Optional[Dict[str, Any]] = None,
     actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT",
 ) -> History:
@@ -138,10 +138,10 @@ def start_simulation(
     keep_initialised: Optional[bool] (default: False)
         Set to True to prevent `ray.shutdown()` in case `ray.is_initialized()=True`.
 
-    actor_type: VirtualClientEngineActor (default: DefaultActor)
+    actor_type: VirtualClientEngineActor (default: ClientAppActor)
         Optionally specify the type of actor to use. The actor object, which
         persists throughout the simulation, will be the process in charge of
-        running the clients' jobs (i.e. their `fit()` method).
+        executing a ClientApp wrapping input argument `client_fn`.
 
     actor_kwargs: Optional[Dict[str, Any]] (default: None)
         If you want to create your own Actor classes, you might need to pass
@@ -219,7 +219,7 @@ def start_simulation(
     log(
         INFO,
         "Optimize your simulation with Flower VCE: "
-        "https://flower.dev/docs/framework/how-to-run-simulations.html",
+        "https://flower.ai/docs/framework/how-to-run-simulations.html",
     )
 
     # Log the resources that a single client will be able to use
@@ -337,7 +337,7 @@ def update_resources(f_stop: threading.Event) -> None:
             "disconnected. The head node might still be alive but cannot accommodate "
             "any actor with resources: %s."
"\nTake a look at the Flower simulation examples for guidance " - ".", + ".", client_resources, client_resources, ) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 853566a4cbeb..08d0576e39f0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -14,29 +14,22 @@ # ============================================================================== """Ray-based Flower Actor and ActorPool implementation.""" - +import asyncio import threading import traceback from abc import ABC -from logging import ERROR, WARNING +from logging import DEBUG, ERROR, WARNING from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import ray from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr import common -from flwr.client import Client, ClientFn -from flwr.common.context import Context +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common import Context, Message from flwr.common.logger import log -from flwr.simulation.ray_transport.utils import check_clientfn_returns_client -# All possible returns by a client -ClientRes = Union[ - common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes -] -# A function to be executed by a client to obtain some results -JobFn = Callable[[Client], ClientRes] +ClientAppFn = Callable[[], ClientApp] class ClientException(Exception): @@ -53,32 +46,33 @@ class VirtualClientEngineActor(ABC): def terminate(self) -> None: """Manually terminate Actor object.""" - log(WARNING, "Manually terminating %s}", self.__class__.__name__) + log(WARNING, "Manually terminating %s", self.__class__.__name__) ray.actor.exit_actor() def run( self, - client_fn: ClientFn, - job_fn: JobFn, + client_app_fn: ClientAppFn, + message: Message, cid: str, context: Context, - ) -> Tuple[str, ClientRes, Context]: + ) -> Tuple[str, Message, Context]: """Run a client run.""" - # Execute tasks and return result + # Pass message through ClientApp and return a message # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: - # Instantiate client (check 'Client' type is returned) - client = check_clientfn_returns_client(client_fn(cid)) - # Inject context - client.set_context(context) - # Run client job - job_results = job_fn(client) - # Retrieve context (potentially updated) - updated_context = client.get_context() + # Load app + app: ClientApp = client_app_fn() + + # Handle task message + out_message = app(message=message, context=context) + + except LoadClientAppError as load_ex: + raise load_ex + except Exception as ex: client_trace = traceback.format_exc() - message = ( + mssg = ( "\n\tSomething went wrong when running your client run." "\n\tClient " + cid @@ -87,13 +81,13 @@ def run( + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) - raise ClientException(str(message)) from ex + raise ClientException(str(mssg)) from ex - return cid, job_results, updated_context + return cid, out_message, context @ray.remote -class DefaultActor(VirtualClientEngineActor): +class ClientAppActor(VirtualClientEngineActor): """A Ray Actor class that runs client runs. 
 
     Parameters
@@ -237,16 +231,16 @@ def add_actors_to_pool(self, num_actors: int) -> None:
             self._idle_actors.extend(new_actors)
         self.num_actors += num_actors
 
-    def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None:
-        """Take idle actor and assign it a client run.
+    def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None:
+        """Take an idle actor and assign it to run a client app and Message.
 
         Submit a job to an actor by first removing it from the list of idle actors, then
-        check if this actor was flagged to be removed from the pool
+        check if this actor was flagged to be removed from the pool.
         """
-        client_fn, job_fn, cid, context = value
+        app_fn, mssg, cid, context = value
         actor = self._idle_actors.pop()
         if self._check_and_remove_actor_from_pool(actor):
-            future = fn(actor, client_fn, job_fn, cid, context)
+            future = fn(actor, app_fn, mssg, cid, context)
             future_key = tuple(future) if isinstance(future, List) else future
             self._future_to_actor[future_key] = (self._next_task_index, actor, cid)
             self._next_task_index += 1
@@ -255,7 +249,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None:
             self._cid_to_future[cid]["future"] = future_key
 
     def submit_client_job(
-        self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, Context]
+        self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
     ) -> None:
         """Submit a job while tracking client ids."""
         _, _, cid, _ = job
@@ -295,7 +289,7 @@ def _is_future_ready(self, cid: str) -> bool:
 
         return self._cid_to_future[cid]["ready"]  # type: ignore
 
-    def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]:
+    def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]:
         """Fetch result and updated context for a VirtualClient from Object Store.
 
         The job submitted by the ClientProxy interfacing with client with cid=cid is
@@ -303,9 +297,9 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]:
         """
         try:
             future: ObjectRef[Any] = self._cid_to_future[cid]["future"]  # type: ignore
-            res_cid, res, updated_context = ray.get(
+            res_cid, out_mssg, updated_context = ray.get(
                 future
-            )  # type: (str, ClientRes, Context)
+            )  # type: (str, Message, Context)
         except ray.exceptions.RayActorError as ex:
             log(ERROR, ex)
             if hasattr(ex, "actor_id"):
@@ -322,7 +316,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]:
         # Reset mapping
         self._reset_cid_to_future_dict(cid)
 
-        return res, updated_context
+        return out_mssg, updated_context
 
     def _flag_actor_for_removal(self, actor_id_hex: str) -> None:
         """Flag actor that should be removed from pool."""
@@ -409,7 +403,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None:
 
     def get_client_result(
         self, cid: str, timeout: Optional[float]
-    ) -> Tuple[ClientRes, Context]:
+    ) -> Tuple[Message, Context]:
         """Get result from VirtualClient with specific cid."""
 
         # Loop until all jobs submitted to the pool are completed. Break early
         # if the result for the ClientProxy calling this method is ready
@@ -423,3 +417,90 @@ def get_client_result(
         # Fetch result belonging to the VirtualClient calling this method
         # Return both result from tasks and (potentially) updated run context
         return self._fetch_future_result(cid)
+
+
+def init_ray(*args: Any, **kwargs: Any) -> None:
+    """Initialises Ray if not already initialised."""
+    if not ray.is_initialized():
+        ray.init(*args, **kwargs)
+
+
+class BasicActorPool:
+    """A basic actor pool."""
+
+    def __init__(
+        self,
+        actor_type: Type[VirtualClientEngineActor],
+        client_resources: Dict[str, Union[int, float]],
+        actor_kwargs: Dict[str, Any],
+    ):
+        self.client_resources = client_resources
+
+        # Queue of idle actors
+        self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
+            maxsize=1024
+        )
+        self.num_actors = 0
+
+        # Resolve arguments to pass during actor init
+        actor_args = {} if actor_kwargs is None else actor_kwargs
+
+        # A function that creates an actor
+        self.create_actor_fn = lambda: actor_type.options(  # type: ignore
+            **client_resources
+        ).remote(**actor_args)
+
+        # Figure out how many actors can be created given the cluster resources
+        # and the resources the user indicates each VirtualClient will need
+        self.actors_capacity = pool_size_from_resources(client_resources)
+        self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
+
+    def is_actor_available(self) -> bool:
+        """Return true if there is an idle actor."""
+        return self.pool.qsize() > 0
+
+    async def add_actors_to_pool(self, num_actors: int) -> None:
+        """Add actors to the pool.
+
+        This method may also be executed if new resources are added to your Ray cluster
+        (e.g. you add a new node).
+        """
+        for _ in range(num_actors):
+            await self.pool.put(self.create_actor_fn())  # type: ignore
+        self.num_actors += num_actors
+
+    async def terminate_all_actors(self) -> None:
+        """Terminate actors in pool."""
+        num_terminated = 0
+        while self.pool.qsize():
+            actor = await self.pool.get()
+            actor.terminate.remote()  # type: ignore
+            num_terminated += 1
+
+        log(DEBUG, "Terminated %i actors", num_terminated)
+
+    async def submit(
+        self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
+    ) -> Any:
+        """On idle actor, submit job and return future."""
+        # Remove idle actor from pool
+        actor = await self.pool.get()
+        # Submit job to actor
+        app_fn, mssg, cid, context = job
+        future = actor_fn(actor, app_fn, mssg, cid, context)
+        # Keep track of future:actor (so we can fetch the actor upon job completion
+        # and add it back to the pool)
+        self._future_to_actor[future] = actor
+        return future
+
+    async def fetch_result_and_return_actor_to_pool(
+        self, future: Any
+    ) -> Tuple[Message, Context]:
+        """Pull result given a future and add actor back to pool."""
+        # Get actor that ran job
+        actor = self._future_to_actor.pop(future)
+        await self.pool.put(actor)
+        # Retrieve result from the object store
+        # Instead of doing ray.get(future) we await it
+        _, out_mssg, updated_context = await future
+        return out_mssg, updated_context
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
index 894012dc6d70..ba31a69af8ee 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
@@ -17,107 +17,32 @@
 
 import traceback
 from logging import ERROR
-from typing import Dict, Optional, cast
-
-import ray
+from typing import Optional
 
 from flwr import common
-from flwr.client import Client, ClientFn
-from flwr.client.client import (
-    maybe_call_evaluate,
-    maybe_call_fit,
-    maybe_call_get_parameters,
-    maybe_call_get_properties,
-)
+from flwr.client import ClientFn
+from flwr.client.client_app import ClientApp
 from flwr.client.node_state import NodeState
+from flwr.common import Message, Metadata, RecordSet
+from flwr.common.constant import (
+    MESSAGE_TYPE_EVALUATE,
+    MESSAGE_TYPE_FIT,
+    MESSAGE_TYPE_GET_PARAMETERS,
+    MESSAGE_TYPE_GET_PROPERTIES,
+)
 from flwr.common.logger import log
-from flwr.server.client_proxy import ClientProxy
-from flwr.simulation.ray_transport.ray_actor import (
-    ClientRes,
-    JobFn,
-    VirtualClientEngineActorPool,
+from flwr.common.recordset_compat import (
+    evaluateins_to_recordset,
+    fitins_to_recordset,
+    getparametersins_to_recordset,
+    getpropertiesins_to_recordset,
+    recordset_to_evaluateres,
+    recordset_to_fitres,
+    recordset_to_getparametersres,
+    recordset_to_getpropertiesres,
 )
-
-
-class RayClientProxy(ClientProxy):
-    """Flower client proxy which delegates work using Ray."""
-
-    def __init__(self, client_fn: ClientFn, cid: str, resources: Dict[str, float]):
-        super().__init__(cid)
-        self.client_fn = client_fn
-        self.resources = resources
-
-    def get_properties(
-        self, ins: common.GetPropertiesIns, timeout: Optional[float]
-    ) -> common.GetPropertiesRes:
-        """Return client's properties."""
-        future_get_properties_res = launch_and_get_properties.options(  # type: ignore
-            **self.resources,
-        ).remote(self.client_fn, self.cid, ins)
-        try:
-            res = ray.get(future_get_properties_res, timeout=timeout)
-        except Exception as ex:
-            log(ERROR, ex)
-            raise ex
-        return cast(
-            common.GetPropertiesRes,
-            res,
-        )
-
-    def get_parameters(
-        self, ins: common.GetParametersIns, timeout: Optional[float]
-    ) -> common.GetParametersRes:
-        """Return the current local model parameters."""
-        future_paramseters_res = launch_and_get_parameters.options(  # type: ignore
-            **self.resources,
-        ).remote(self.client_fn, self.cid, ins)
-        try:
-            res = ray.get(future_paramseters_res, timeout=timeout)
-        except Exception as ex:
-            log(ERROR, ex)
-            raise ex
-        return cast(
-            common.GetParametersRes,
-            res,
-        )
-
-    def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
-        """Train model parameters on the locally held dataset."""
-        future_fit_res = launch_and_fit.options(  # type: ignore
-            **self.resources,
-        ).remote(self.client_fn, self.cid, ins)
-        try:
-            res = ray.get(future_fit_res, timeout=timeout)
-        except Exception as ex:
-            log(ERROR, ex)
-            raise ex
-        return cast(
-            common.FitRes,
-            res,
-        )
-
-    def evaluate(
-        self, ins: common.EvaluateIns, timeout: Optional[float]
-    ) -> common.EvaluateRes:
-        """Evaluate model parameters on the locally held dataset."""
-        future_evaluate_res = launch_and_evaluate.options(  # type: ignore
-            **self.resources,
-        ).remote(self.client_fn, self.cid, ins)
-        try:
-            res = ray.get(future_evaluate_res, timeout=timeout)
-        except Exception as ex:
-            log(ERROR, ex)
-            raise ex
-        return cast(
-            common.EvaluateRes,
-            res,
-        )
-
-    def reconnect(
-        self, ins: common.ReconnectIns, timeout: Optional[float]
-    ) -> common.DisconnectRes:
-        """Disconnect and (optionally) reconnect later."""
-        return common.DisconnectRes(reason="")  # Nothing to do here (yet)
+from flwr.server.client_proxy import ClientProxy
+from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
 
 
 class RayActorClientProxy(ClientProxy):
@@ -127,15 +52,17 @@ def __init__(
         self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool
     ):
         super().__init__(cid)
-        self.client_fn = client_fn
+
+        def _load_app() -> ClientApp:
+            return ClientApp(client_fn=client_fn)
+
+        self.app_fn = _load_app
         self.actor_pool = actor_pool
         self.proxy_state = NodeState()
 
-    def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes:
-        # The VCE is not exposed to TaskIns, it won't handle multilple runs
-        # For the time being, fixing run_id is a small compromise
-        # This will be one of the first points to address integrating VCE + DriverAPI
-        run_id = 0
+    def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
+        """Submit a message to the ActorPool."""
+        run_id = message.metadata.run_id
 
         # Register state
         self.proxy_state.register_context(run_id=run_id)
@@ -145,10 +72,12 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes:
 
         try:
             self.actor_pool.submit_client_job(
-                lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
-                (self.client_fn, job_fn, self.cid, state),
+                lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
+                (self.app_fn, message, self.cid, state),
+            )
+            out_mssg, updated_context = self.actor_pool.get_client_result(
+                self.cid, timeout
             )
-            res, updated_context = self.actor_pool.get_client_result(self.cid, timeout)
 
             # Update state
             self.proxy_state.update_context(run_id=run_id, context=updated_context)
@@ -162,134 +91,110 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes:
             log(ERROR, ex)
             raise ex
 
-        return res
+        return out_mssg
+
+    def _wrap_recordset_in_message(
+        self,
+        recordset: RecordSet,
+        message_type: str,
+        timeout: Optional[float],
+        group_id: Optional[int],
+    ) -> Message:
+        """Wrap a RecordSet inside a Message."""
+        return Message(
+            content=recordset,
+            metadata=Metadata(
+                run_id=0,
+                message_id="",
+                group_id=str(group_id) if group_id is not None else "",
+                src_node_id=0,
+                dst_node_id=int(self.cid),
+                reply_to_message="",
+                ttl=str(timeout) if timeout else "",
+                message_type=message_type,
+                partition_id=int(self.cid),
+            ),
+        )
 
     def get_properties(
-        self, ins: common.GetPropertiesIns, timeout: Optional[float]
+        self,
+        ins: common.GetPropertiesIns,
+        timeout: Optional[float],
+        group_id: Optional[int],
    ) -> common.GetPropertiesRes:
         """Return client's properties."""
+        recordset = getpropertiesins_to_recordset(ins)
+        message = self._wrap_recordset_in_message(
+            recordset,
+            message_type=MESSAGE_TYPE_GET_PROPERTIES,
+            timeout=timeout,
+            group_id=group_id,
+        )
 
-        def get_properties(client: Client) -> common.GetPropertiesRes:
-            return maybe_call_get_properties(
-                client=client,
-                get_properties_ins=ins,
-            )
-
-        res = self._submit_job(get_properties, timeout)
+        message_out = self._submit_job(message, timeout)
 
-        return cast(
-            common.GetPropertiesRes,
-            res,
-        )
+        return recordset_to_getpropertiesres(message_out.content)
 
     def get_parameters(
-        self, ins: common.GetParametersIns, timeout: Optional[float]
+        self,
+        ins: common.GetParametersIns,
+        timeout: Optional[float],
+        group_id: Optional[int],
     ) -> common.GetParametersRes:
         """Return the current local model parameters."""
+        recordset = getparametersins_to_recordset(ins)
+        message = self._wrap_recordset_in_message(
+            recordset,
+            message_type=MESSAGE_TYPE_GET_PARAMETERS,
+            timeout=timeout,
+            group_id=group_id,
+        )
 
-        def get_parameters(client: Client) -> common.GetParametersRes:
-            return maybe_call_get_parameters(
-                client=client,
-                get_parameters_ins=ins,
-            )
+        message_out = self._submit_job(message, timeout)
 
-        res = self._submit_job(get_parameters, timeout)
+        return recordset_to_getparametersres(message_out.content, keep_input=False)
 
-        return cast(
-            common.GetParametersRes,
-            res,
-        )
-
-    def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
+    def fit(
+        self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
+    ) -> common.FitRes:
         """Train model parameters on the locally held dataset."""
+        recordset = fitins_to_recordset(
+            ins, keep_input=True
+        )  # This must stay TRUE since ins are in-memory
+        message = self._wrap_recordset_in_message(
+            recordset,
+            message_type=MESSAGE_TYPE_FIT,
+            timeout=timeout,
+            group_id=group_id,
+        )
 
-        def fit(client: Client) -> common.FitRes:
-            return maybe_call_fit(
-                client=client,
-                fit_ins=ins,
-            )
-
-        res = self._submit_job(fit, timeout)
+        message_out = self._submit_job(message, timeout)
 
-        return cast(
-            common.FitRes,
-            res,
-        )
+        return recordset_to_fitres(message_out.content, keep_input=False)
 
     def evaluate(
-        self, ins: common.EvaluateIns, timeout: Optional[float]
+        self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
     ) -> common.EvaluateRes:
         """Evaluate model parameters on the locally held dataset."""
+        recordset = evaluateins_to_recordset(
+            ins, keep_input=True
+        )  # This must stay TRUE since ins are in-memory
+        message = self._wrap_recordset_in_message(
+            recordset,
+            message_type=MESSAGE_TYPE_EVALUATE,
+            timeout=timeout,
+            group_id=group_id,
+        )
 
-        def evaluate(client: Client) -> common.EvaluateRes:
-            return maybe_call_evaluate(
-                client=client,
-                evaluate_ins=ins,
-            )
+        message_out = self._submit_job(message, timeout)
 
-        res = self._submit_job(evaluate, timeout)
-
-        return cast(
-            common.EvaluateRes,
-            res,
-        )
+        return recordset_to_evaluateres(message_out.content)
 
     def reconnect(
-        self, ins: common.ReconnectIns, timeout: Optional[float]
+        self,
+        ins: common.ReconnectIns,
+        timeout: Optional[float],
+        group_id: Optional[int],
     ) -> common.DisconnectRes:
         """Disconnect and (optionally) reconnect later."""
         return common.DisconnectRes(reason="")  # Nothing to do here (yet)
-
-
-@ray.remote
-def launch_and_get_properties(
-    client_fn: ClientFn, cid: str, get_properties_ins: common.GetPropertiesIns
-) -> common.GetPropertiesRes:
-    """Exectue get_properties remotely."""
-    client: Client = _create_client(client_fn, cid)
-    return maybe_call_get_properties(
-        client=client,
-        get_properties_ins=get_properties_ins,
-    )
-
-
-@ray.remote
-def launch_and_get_parameters(
-    client_fn: ClientFn, cid: str, get_parameters_ins: common.GetParametersIns
-) -> common.GetParametersRes:
-    """Exectue get_parameters remotely."""
-    client: Client = _create_client(client_fn, cid)
-    return maybe_call_get_parameters(
-        client=client,
-        get_parameters_ins=get_parameters_ins,
-    )
-
-
-@ray.remote
-def launch_and_fit(
-    client_fn: ClientFn, cid: str, fit_ins: common.FitIns
-) -> common.FitRes:
-    """Exectue fit remotely."""
-    client: Client = _create_client(client_fn, cid)
-    return maybe_call_fit(
-        client=client,
-        fit_ins=fit_ins,
-    )
-
-
-@ray.remote
-def launch_and_evaluate(
-    client_fn: ClientFn, cid: str, evaluate_ins: common.EvaluateIns
-) -> common.EvaluateRes:
-    """Exectue evaluate remotely."""
-    client: Client = _create_client(client_fn, cid)
-    return maybe_call_evaluate(
-        client=client,
-        evaluate_ins=evaluate_ins,
-    )
-
-
-def _create_client(client_fn: ClientFn, cid: str) -> Client:
-    """Create a client instance."""
-    # Materialize client
-    return client_fn(cid)
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
index b380d37d01c8..e59033ad39b5 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
@@ -17,19 +17,29 @@
 
 from math import pi
 from random import shuffle
-from typing import List, Tuple, Type, cast
+from typing import Dict, List, Tuple, Type
 
 import ray
 
 from flwr.client import Client, NumPyClient
-from flwr.common import Code, GetPropertiesRes, Status
-from flwr.common.configsrecord import ConfigsRecord
-from flwr.common.context import Context
-from flwr.common.recordset import RecordSet
+from flwr.client.client_app import ClientApp
+from flwr.common import (
+    Config,
+    ConfigsRecord,
+    Context,
+    Message,
+    Metadata,
+    RecordSet,
+    Scalar,
+)
+from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
+from flwr.common.recordset_compat import (
+    getpropertiesins_to_recordset,
+    recordset_to_getpropertiesres,
+)
+from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
 from flwr.simulation.ray_transport.ray_actor import (
-    ClientRes,
-    DefaultActor,
-    JobFn,
+    ClientAppActor,
     VirtualClientEngineActor,
     VirtualClientEngineActorPool,
 )
@@ -42,34 +52,24 @@ class DummyClient(NumPyClient):
     def __init__(self, cid: str) -> None:
         self.cid = int(cid)
 
-
-def get_dummy_client(cid: str) -> Client:
-    """Return a DummyClient converted to Client type."""
-    return DummyClient(cid).to_client()
-
-
-# A dummy run
-def job_fn(cid: str) -> JobFn:  # pragma: no cover
-    """Construct a simple job with cid dependency."""
-
-    def cid_times_pi(client: Client) -> ClientRes:  # pylint: disable=unused-argument
-        result = int(cid) * pi
+    def get_properties(self, config: Config) -> Dict[str, Scalar]:
+        """Return properties by doing a simple calculation."""
+        result = int(self.cid) * pi
 
         # store something in context
-        client.numpy_client.context.state.set_configs(  # type: ignore
-            "result", record=ConfigsRecord({"result": str(result)})
+        self.context.state.configs_records["result"] = ConfigsRecord(
+            {"result": str(result)}
         )
+        return {"result": result}
 
-        # now let's convert it to a GetPropertiesRes response
-        return GetPropertiesRes(
-            status=Status(Code(0), message="test"), properties={"result": result}
-        )
 
-    return cid_times_pi
+def get_dummy_client(cid: str) -> Client:
+    """Return a DummyClient converted to Client type."""
+    return DummyClient(cid).to_client()
 
 
 def prep(
-    actor_type: Type[VirtualClientEngineActor] = DefaultActor,
+    actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
 ) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]:  # pragma: no cover
     """Prepare ClientProxies and pool for tests."""
     client_resources = {"num_cpus": 1, "num_gpus": 0.0}
@@ -104,13 +104,24 @@ def test_cid_consistency_one_at_a_time() -> None:
 
     Submit one job and waits for completion. Then submits the next and so on
     """
     proxies, _ = prep()
+
+    getproperties_ins = _get_valid_getpropertiesins()
+    recordset = getpropertiesins_to_recordset(getproperties_ins)
+
     # submit jobs one at a time
     for prox in proxies:
-        res = prox._submit_job(  # pylint: disable=protected-access
-            job_fn=job_fn(prox.cid), timeout=None
+        message = prox._wrap_recordset_in_message(  # pylint: disable=protected-access
+            recordset,
+            MESSAGE_TYPE_GET_PROPERTIES,
+            timeout=None,
+            group_id=0,
         )
+        message_out = prox._submit_job(  # pylint: disable=protected-access
+            message=message, timeout=None
+        )
+
+        res = recordset_to_getpropertiesres(message_out.content)
 
-        res = cast(GetPropertiesRes, res)
         assert int(prox.cid) * pi == res.properties["result"]
 
     ray.shutdown()
@@ -125,6 +136,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
     proxies, _ = prep()
     run_id = 0
 
+    getproperties_ins = _get_valid_getpropertiesins()
+    recordset = getpropertiesins_to_recordset(getproperties_ins)
+
     # submit all jobs (collect later)
     shuffle(proxies)
     for prox in proxies:
@@ -133,25 +147,32 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
         # Retrieve state
         state = prox.proxy_state.retrieve_context(run_id=run_id)
 
-        job = job_fn(prox.cid)
+        message = prox._wrap_recordset_in_message(  # pylint: disable=protected-access
+            recordset,
+            message_type=MESSAGE_TYPE_GET_PROPERTIES,
+            timeout=None,
+            group_id=0,
+        )
         prox.actor_pool.submit_client_job(
-            lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
-            (prox.client_fn, job, prox.cid, state),
+            lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
+            (prox.app_fn, message, prox.cid, state),
         )
 
     # fetch results one at a time
     shuffle(proxies)
     for prox in proxies:
-        res, updated_context = prox.actor_pool.get_client_result(prox.cid, timeout=None)
+        message_out, updated_context = prox.actor_pool.get_client_result(
+            prox.cid, timeout=None
+        )
         prox.proxy_state.update_context(run_id, context=updated_context)
-        res = cast(GetPropertiesRes, res)
+        res = recordset_to_getpropertiesres(message_out.content)
 
         assert int(prox.cid) * pi == res.properties["result"]
         assert (
             str(int(prox.cid) * pi)
-            == prox.proxy_state.retrieve_context(run_id).state.get_configs("result")[
+            == prox.proxy_state.retrieve_context(run_id).state.configs_records[
                 "result"
-            ]
+            ]["result"]
         )
 
     ray.shutdown()
@@ -163,20 +184,39 @@ def test_cid_consistency_without_proxies() -> None:
     num_clients = len(proxies)
     cids = [str(cid) for cid in range(num_clients)]
 
+    getproperties_ins = _get_valid_getpropertiesins()
+    recordset = getpropertiesins_to_recordset(getproperties_ins)
+
+    def _load_app() -> ClientApp:
+        return ClientApp(client_fn=get_dummy_client)
+
     # submit all jobs (collect later)
     shuffle(cids)
     for cid in cids:
-        job = job_fn(cid)
+        message = Message(
+            content=recordset,
+            metadata=Metadata(
+                run_id=0,
+                message_id="",
+                group_id=str(0),
+                src_node_id=0,
+                dst_node_id=12345,
+                reply_to_message="",
+                ttl="",
+                message_type=MESSAGE_TYPE_GET_PROPERTIES,
+                partition_id=int(cid),
+            ),
+        )
         pool.submit_client_job(
             lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state),
-            (get_dummy_client, job, cid, Context(state=RecordSet())),
+            (_load_app, message, cid, Context(state=RecordSet())),
         )
 
     # fetch results one at a time
     shuffle(cids)
    for cid in cids:
-        res, _ = pool.get_client_result(cid, timeout=None)
-        res = cast(GetPropertiesRes, res)
+        message_out, _ = pool.get_client_result(cid, timeout=None)
+        res = recordset_to_getpropertiesres(message_out.content)
         assert int(cid) * pi == res.properties["result"]
 
     ray.shutdown()
diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py
index dd9fb6b2aa85..3861164998a4 100644
--- a/src/py/flwr/simulation/ray_transport/utils.py
+++ b/src/py/flwr/simulation/ray_transport/utils.py
@@ -18,7 +18,6 @@
 import warnings
 from logging import ERROR
 
-from flwr.client import Client
 from flwr.common.logger import log
 
 try:
@@ -60,25 +59,3 @@ def enable_tf_gpu_growth() -> None:
             log(ERROR, traceback.format_exc())
             log(ERROR, ex)
             raise ex
-
-
-def check_clientfn_returns_client(client: Client) -> Client:
-    """Warn once that clients returned in `clinet_fn` should be of type Client.
-
-    This is here for backwards compatibility. If a ClientFn is provided returning
-    a different type of client (e.g. NumPyClient) we'll warn the user but convert
-    the client internally to `Client` by calling `.to_client()`.
-    """
-    if not isinstance(client, Client):
-        mssg = (
-            " Ensure your client is of type `flwr.client.Client`. Please convert it"
-            " using the `.to_client()` method before returning it"
-            " in the `client_fn` you pass to `start_simulation`."
-            " We have applied this conversion on your behalf."
-            " Not returning a `Client` might trigger an error in future"
-            " versions of Flower."
-        )
-
-        warnings.warn(mssg, DeprecationWarning, stacklevel=2)
-        client = client.to_client()
-    return client
diff --git a/src/py/flwr_experimental/baseline/tf_cifar/settings.py b/src/py/flwr_experimental/baseline/tf_cifar/settings.py
index 829856af0ace..ed1a72cafac9 100644
--- a/src/py/flwr_experimental/baseline/tf_cifar/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_cifar/settings.py
@@ -120,9 +120,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
index b0de9841de30..72adc1f0be04 100644
--- a/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_fashion_mnist/settings.py
@@ -141,9 +141,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
index 18c58d1333a5..5bfb7b1e42ad 100644
--- a/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
+++ b/src/py/flwr_experimental/baseline/tf_hotkey/settings.py
@@ -144,9 +144,9 @@ def configure_clients(
             cid=str(i),
             partition=i,
             # Indices 0 to 49 fast, 50 to 99 slow
-            delay_factor=delay_factor_fast
-            if i < int(num_clients / 2)
-            else delay_factor_slow,
+            delay_factor=(
+                delay_factor_fast if i < int(num_clients / 2) else delay_factor_slow
+            ),
             # Shared
             iid_fraction=iid_fraction,
             num_clients=num_clients,
diff --git a/src/py/flwr_experimental/logserver/server.py b/src/py/flwr_experimental/logserver/server.py
index dc729cebbf33..683b12b6db6c 100644
--- a/src/py/flwr_experimental/logserver/server.py
+++ b/src/py/flwr_experimental/logserver/server.py
@@ -69,9 +69,9 @@ def upload_file(local_filepath: str, s3_key: Optional[str]) -> None:
             Bucket=CONFIG["s3_bucket"],
             Key=s3_key,
             ExtraArgs={
-                "ContentType": "application/pdf"
-                if s3_key.endswith(".pdf")
-                else "text/plain"
+                "ContentType": (
+                    "application/pdf" if s3_key.endswith(".pdf") else "text/plain"
+                )
             },
         )
     # pylint: disable=broad-except
diff --git a/src/py/flwr_experimental/ops/cluster.py b/src/py/flwr_experimental/ops/cluster.py
index dfd0dd13727e..692704676f4a 100644
--- a/src/py/flwr_experimental/ops/cluster.py
+++ b/src/py/flwr_experimental/ops/cluster.py
@@ -73,10 +73,10 @@ def ssh_connection(
     username, key_filename = ssh_credentials
 
     instance_ssh_port: int = cast(int, instance.ssh_port)
-    ignore_host_key_policy: Union[
-        Type[MissingHostKeyPolicy], MissingHostKeyPolicy
-    ] = cast(
-        Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy
+    ignore_host_key_policy: Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy] = (
+        cast(
+            Union[Type[MissingHostKeyPolicy], MissingHostKeyPolicy], IgnoreHostKeyPolicy
+        )
     )
 
     client = SSHClient()