diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ccf031344f67..3e314c8d1de5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -35,3 +35,4 @@ README.md @jafermarq @tanertopal @danieljanes /.devcontainer @Robert-Steiner @Moep90 **/Dockerfile @Robert-Steiner @Moep90 **/*.Dockerfile @Robert-Steiner @Moep90 +src/docker @Robert-Steiner @Moep90 diff --git a/.github/workflows/docker-build-main.yml b/.github/workflows/docker-build-main.yml new file mode 100644 index 000000000000..81ef845eae29 --- /dev/null +++ b/.github/workflows/docker-build-main.yml @@ -0,0 +1,69 @@ +name: Build Docker Images Main Branch + +on: + push: + branches: + - 'main' + +jobs: + parameters: + if: github.repository == 'adap/flower' + name: Collect docker build parameters + runs-on: ubuntu-22.04 + timeout-minutes: 10 + outputs: + pip-version: ${{ steps.versions.outputs.pip-version }} + setuptools-version: ${{ steps.versions.outputs.setuptools-version }} + flwr-version-ref: ${{ steps.versions.outputs.flwr-version-ref }} + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + + - uses: ./.github/actions/bootstrap + id: bootstrap + + - id: versions + run: | + echo "pip-version=${{ steps.bootstrap.outputs.pip-version }}" >> "$GITHUB_OUTPUT" + echo "setuptools-version=${{ steps.bootstrap.outputs.setuptools-version }}" >> "$GITHUB_OUTPUT" + echo "flwr-version-ref=git+${{ github.server_url }}/${{ github.repository }}.git@${{ github.sha }}" >> "$GITHUB_OUTPUT" + + build-docker-base-images: + name: Build base images + if: github.repository == 'adap/flower' + uses: ./.github/workflows/_docker-build.yml + needs: parameters + with: + namespace-repository: flwr/base + file-dir: src/docker/base/ubuntu + build-args: | + PIP_VERSION=${{ needs.parameters.outputs.pip-version }} + SETUPTOOLS_VERSION=${{ needs.parameters.outputs.setuptools-version }} + FLWR_VERSION_REF=${{ needs.parameters.outputs.flwr-version-ref }} + tags: unstable + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} + + build-docker-binary-images: + name: Build binary images + if: github.repository == 'adap/flower' + uses: ./.github/workflows/_docker-build.yml + needs: build-docker-base-images + strategy: + fail-fast: false + matrix: + images: [ + { repository: "flwr/superlink", file_dir: "src/docker/superlink" }, + { repository: "flwr/supernode", file_dir: "src/docker/supernode" }, + { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" }, + { repository: "flwr/superexec", file_dir: "src/docker/superexec" }, + { repository: "flwr/clientapp", file_dir: "src/docker/clientapp" } + ] + with: + namespace-repository: ${{ matrix.images.repository }} + file-dir: ${{ matrix.images.file_dir }} + build-args: BASE_IMAGE=unstable + tags: unstable + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 49e5b7bf1b36..815d6422848b 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -146,8 +146,6 @@ jobs: if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} run: | python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} - - name: Install e2e components - run: pip install . - name: Download dataset if: ${{ matrix.dataset }} run: python -c "${{ matrix.dataset }}" @@ -172,7 +170,7 @@ jobs: run: ./../test_superlink.sh bare sqlite - name: Run driver test with client authentication if: ${{ matrix.directory == 'e2e-bare-auth' }} - run: ./../test_superlink.sh bare client-auth + run: ./../test_superlink.sh "${{ matrix.directory }}" client-auth - name: Run reconnection test with SQLite database if: ${{ matrix.directory == 'e2e-bare' }} run: ./../test_reconnection.sh sqlite diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index 812d5b1e398e..e608329872de 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -16,6 +16,8 @@ jobs: if: ${{ github.repository == 'adap/flower' }} name: Publish release runs-on: ubuntu-22.04 + outputs: + flwr-version: ${{ steps.publish.outputs.flwr-version }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -26,10 +28,12 @@ jobs: uses: ./.github/actions/bootstrap - name: Get artifacts and publish + id: publish env: GITHUB_REF: ${{ github.ref }} run: | TAG_NAME=$(echo "${GITHUB_REF_NAME}" | cut -c2-) + echo "flwr-version=$TAG_NAME" >> "$GITHUB_OUTPUT" wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" tar_name="flwr-${TAG_NAME}.tar.gz" @@ -67,8 +71,7 @@ jobs: - id: matrix run: | - FLWR_VERSION=$(poetry version -s) - python dev/build-docker-image-matrix.py --flwr-version "${FLWR_VERSION}" > matrix.json + python dev/build-docker-image-matrix.py --flwr-version "${{ needs.publish.outputs.flwr-version }}" > matrix.json echo "matrix=$(cat matrix.json)" >> $GITHUB_OUTPUT build-base-images: diff --git a/benchmarks/flowertune-llm/evaluation/README.md b/benchmarks/flowertune-llm/evaluation/README.md index 1b6383df296a..d7216c089d8a 100644 --- a/benchmarks/flowertune-llm/evaluation/README.md +++ b/benchmarks/flowertune-llm/evaluation/README.md @@ -37,7 +37,7 @@ The default template generated by `flwr new` (see the [Project Creation Instruct | | MBPP | HumanEval | MultiPL-E (JS) | MultiPL-E (C++) | Avg | |:----------:|:-----:|:---------:|:--------------:|:---------------:|:-----:| -| Pass@1 (%) | 32.60 | 26.83 | 29.81 | 24.22 | 28.37 | +| Pass@1 (%) | 31.60 | 23.78 | 28.57 | 25.47 | 27.36 | ## Make submission on FlowerTune LLM Leaderboard diff --git a/benchmarks/flowertune-llm/evaluation/general-nlp/README.md b/benchmarks/flowertune-llm/evaluation/general-nlp/README.md index 51c801494f6d..18666968108d 100644 --- a/benchmarks/flowertune-llm/evaluation/general-nlp/README.md +++ b/benchmarks/flowertune-llm/evaluation/general-nlp/README.md @@ -23,7 +23,7 @@ huggingface-cli login Download data from [FastChat](https://github.com/lm-sys/FastChat): ```shell -git clone --depth=1 https://github.com/lm-sys/FastChat.git && cd FastChat && git checkout d561f87b24de197e25e3ddf7e09af93ced8dfe36 && mv fastchat/llm_judge/data ../data && cd .. && rm -rf FastChat +git clone https://github.com/lm-sys/FastChat.git && cd FastChat && git checkout d561f87b24de197e25e3ddf7e09af93ced8dfe36 && mv fastchat/llm_judge/data ../data && cd .. && rm -rf FastChat ``` diff --git a/benchmarks/flowertune-llm/evaluation/medical/README.md b/benchmarks/flowertune-llm/evaluation/medical/README.md new file mode 100644 index 000000000000..78de069460d8 --- /dev/null +++ b/benchmarks/flowertune-llm/evaluation/medical/README.md @@ -0,0 +1,38 @@ +# Evaluation for Medical challenge + +We build up a medical question answering (QA) pipeline to evaluate our fined-tuned LLMs. +Three datasets have been selected for this evaluation: [PubMedQA](https://huggingface.co/datasets/bigbio/pubmed_qa), [MedMCQA](https://huggingface.co/datasets/medmcqa), and [MedQA](https://huggingface.co/datasets/bigbio/med_qa). + + +## Environment Setup + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/benchmarks/flowertune-llm/evaluation/medical ./flowertune-eval-medical && rm -rf flower && cd flowertune-eval-medical +``` + +Create a new Python environment (we recommend Python 3.10), activate it, then install dependencies with: + +```shell +# From a new python environment, run: +pip install -r requirements.txt + +# Log in HuggingFace account +huggingface-cli login +``` + +## Generate model decision & calculate accuracy + +```bash +python eval.py \ +--peft-path=/path/to/fine-tuned-peft-model-dir/ \ # e.g., ./peft_1 +--run-name=fl \ # specified name for this run +--batch-size=16 \ +--quantization=4 \ +--datasets=pubmedqa,medmcqa,medqa +``` + +The model answers and accuracy values will be saved to `benchmarks/generation_{dataset_name}_{run_name}.jsonl` and `benchmarks/acc_{dataset_name}_{run_name}.txt`, respectively. + + +> [!NOTE] +> Please ensure that you provide all **three accuracy values (PubMedQA, MedMCQA, MedQA)** for three evaluation datasets when submitting to the LLM Leaderboard (see the [`Make Submission`](https://github.com/adap/flower/tree/main/benchmarks/flowertune-llm/evaluation#make-submission-on-flowertune-llm-leaderboard) section). diff --git a/benchmarks/flowertune-llm/evaluation/medical/benchmarks.py b/benchmarks/flowertune-llm/evaluation/medical/benchmarks.py new file mode 100644 index 000000000000..c72e2a7894da --- /dev/null +++ b/benchmarks/flowertune-llm/evaluation/medical/benchmarks.py @@ -0,0 +1,174 @@ +import json + +import pandas as pd +from sklearn.metrics import accuracy_score +from torch.utils.data import DataLoader +from tqdm import tqdm +from utils import format_answer, format_example, save_results + +import datasets + +# The instructions refer to Meditron evaluation: +# https://github.com/epfLLM/meditron/blob/main/evaluation/instructions.json +INSTRUCTIONS = { + "pubmedqa": "As an expert doctor in clinical science and medical knowledge, can you tell me if the following statement is correct? Answer yes, no, or maybe.", + "medqa": "You are a medical doctor taking the US Medical Licensing Examination. You need to demonstrate your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy. Show your ability to apply the knowledge essential for medical practice. For the following multiple-choice question, select one correct answer from A to E. Base your answer on the current and standard practices referenced in medical guidelines.", + "medmcqa": "You are a medical doctor answering realworld medical entrance exam questions. Based on your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy, answer the following multiple-choice question. Select one correct answer from A to D. Base your answer on the current and standard practices referenced in medical guidelines.", +} + + +def infer_pubmedqa(model, tokenizer, batch_size, run_name): + name = "pubmedqa" + answer_type = "boolean" + dataset = datasets.load_dataset( + "bigbio/pubmed_qa", + "pubmed_qa_labeled_fold0_source", + split="test", + trust_remote_code=True, + ) + # Post process + instruction = INSTRUCTIONS[name] + + def post_process(row): + context = "\n".join(row["CONTEXTS"]) + row["prompt"] = f"{context}\n{row['QUESTION']}" + row["gold"] = row["final_decision"] + row["long_answer"] = row["LONG_ANSWER"] + row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n" + return row + + dataset = dataset.map(post_process) + + # Generate results + generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type) + + +def infer_medqa(model, tokenizer, batch_size, run_name): + name = "medqa" + answer_type = "mcq" + dataset = datasets.load_dataset( + "bigbio/med_qa", + "med_qa_en_4options_source", + split="test", + trust_remote_code=True, + ) + + # Post process + instruction = INSTRUCTIONS[name] + + def post_process(row): + choices = [opt["value"] for opt in row["options"]] + row["prompt"] = format_example(row["question"], choices) + for opt in row["options"]: + if opt["value"] == row["answer"]: + row["gold"] = opt["key"] + break + row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n" + return row + + dataset = dataset.map(post_process) + + # Generate results + generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type) + + +def infer_medmcqa(model, tokenizer, batch_size, run_name): + name = "medmcqa" + answer_type = "mcq" + dataset = datasets.load_dataset( + "medmcqa", split="validation", trust_remote_code=True + ) + + # Post process + instruction = INSTRUCTIONS[name] + + def post_process(row): + options = [row["opa"], row["opb"], row["opc"], row["opd"]] + answer = int(row["cop"]) + row["prompt"] = format_example(row["question"], options) + row["gold"] = chr(ord("A") + answer) if answer in [0, 1, 2, 3] else None + row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n" + return row + + dataset = dataset.map(post_process) + + # Generate results + generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type) + + +def generate_results( + name, run_name, dataset, model, tokenizer, batch_size, answer_type +): + # Run inference + prediction = inference(dataset, model, tokenizer, batch_size) + + # Calculate accuracy + acc = accuracy_compute(prediction, answer_type) + + # Save results and generations + save_results(name, run_name, prediction, acc) + + +def inference(dataset, model, tokenizer, batch_size): + columns_process = ["prompt", "gold"] + dataset_process = pd.DataFrame(dataset, columns=dataset.features)[columns_process] + dataset_process = dataset_process.assign(output="Null") + temperature = 1.0 + + inference_data = json.loads(dataset_process.to_json(orient="records")) + data_loader = DataLoader(inference_data, batch_size=batch_size, shuffle=False) + + batch_counter = 0 + for batch in tqdm(data_loader, total=len(data_loader), position=0, leave=True): + prompts = [ + f"<|im_start|>question\n{prompt}<|im_end|>\n<|im_start|>answer\n" + for prompt in batch["prompt"] + ] + if batch_counter == 0: + print(prompts[0]) + + # Process tokenizer + stop_seq = ["###"] + if tokenizer.eos_token is not None: + stop_seq.append(tokenizer.eos_token) + if tokenizer.pad_token is not None: + stop_seq.append(tokenizer.pad_token) + max_new_tokens = len( + tokenizer(batch["gold"][0], add_special_tokens=False)["input_ids"] + ) + + outputs = [] + for prompt in prompts: + input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") + output_ids = model.generate( + inputs=input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + top_p=1.0, + temperature=temperature, + pad_token_id=tokenizer.eos_token_id, + ) + output_ids = output_ids[0][len(input_ids[0]) :] + output = tokenizer.decode(output_ids, skip_special_tokens=True) + outputs.append(output) + + for prompt, out in zip(batch["prompt"], outputs): + dataset_process.loc[dataset_process["prompt"] == prompt, "output"] = out + batch_counter += 1 + + return dataset_process + + +def accuracy_compute(dataset, answer_type): + dataset = json.loads(dataset.to_json(orient="records")) + preds, golds = [], [] + for row in dataset: + answer = row["gold"].lower() + output = row["output"].lower() + pred, gold = format_answer(output, answer, answer_type=answer_type) + preds.append(pred) + golds.append(gold) + + accuracy = accuracy_score(preds, golds) + + return accuracy diff --git a/benchmarks/flowertune-llm/evaluation/medical/eval.py b/benchmarks/flowertune-llm/evaluation/medical/eval.py new file mode 100644 index 000000000000..7405e1493e4d --- /dev/null +++ b/benchmarks/flowertune-llm/evaluation/medical/eval.py @@ -0,0 +1,62 @@ +import argparse + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +from benchmarks import infer_medmcqa, infer_medqa, infer_pubmedqa + +# Fixed seed +torch.manual_seed(2024) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--base-model-name-path", type=str, default="mistralai/Mistral-7B-v0.3" +) +parser.add_argument("--run-name", type=str, default="fl") +parser.add_argument("--peft-path", type=str, default=None) +parser.add_argument( + "--datasets", + type=str, + default="pubmedqa", + help="The dataset to infer on: [pubmedqa, medqa, medmcqa]", +) +parser.add_argument("--batch-size", type=int, default=16) +parser.add_argument("--quantization", type=int, default=4) +args = parser.parse_args() + + +# Load model and tokenizer +if args.quantization == 4: + quantization_config = BitsAndBytesConfig(load_in_4bit=True) + torch_dtype = torch.float32 +elif args.quantization == 8: + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + torch_dtype = torch.float16 +else: + raise ValueError( + f"Use 4-bit or 8-bit quantization. You passed: {args.quantization}/" + ) + +model = AutoModelForCausalLM.from_pretrained( + args.base_model_name_path, + quantization_config=quantization_config, + torch_dtype=torch_dtype, +) +if args.peft_path is not None: + model = PeftModel.from_pretrained( + model, args.peft_path, torch_dtype=torch_dtype + ).to("cuda") + +tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_path) + +# Evaluate +for dataset in args.datasets.split(","): + if dataset == "pubmedqa": + infer_pubmedqa(model, tokenizer, args.batch_size, args.run_name) + elif dataset == "medqa": + infer_medqa(model, tokenizer, args.batch_size, args.run_name) + elif dataset == "medmcqa": + infer_medmcqa(model, tokenizer, args.batch_size, args.run_name) + else: + raise ValueError("Undefined Dataset.") diff --git a/benchmarks/flowertune-llm/evaluation/medical/requirements.txt b/benchmarks/flowertune-llm/evaluation/medical/requirements.txt new file mode 100644 index 000000000000..adfc8b0c59db --- /dev/null +++ b/benchmarks/flowertune-llm/evaluation/medical/requirements.txt @@ -0,0 +1,7 @@ +peft==0.6.2 +pandas==2.2.2 +scikit-learn==1.5.0 +datasets==2.20.0 +sentencepiece==0.2.0 +protobuf==5.27.1 +bitsandbytes==0.43.1 diff --git a/benchmarks/flowertune-llm/evaluation/medical/utils.py b/benchmarks/flowertune-llm/evaluation/medical/utils.py new file mode 100644 index 000000000000..44d0763d39d4 --- /dev/null +++ b/benchmarks/flowertune-llm/evaluation/medical/utils.py @@ -0,0 +1,81 @@ +import os +import re + + +def format_example(question, choices): + if not question.endswith("?") and not question.endswith("."): + question += "?" + options_str = "\n".join([f"{chr(65+i)}. {choices[i]}" for i in range(len(choices))]) + prompt = "Question: " + question + "\n\nOptions:\n" + options_str + return prompt + + +def save_results(dataset_name, run_name, dataset, acc): + path = "./benchmarks/" + if not os.path.exists(path): + os.makedirs(path) + + # Save results + results_path = os.path.join(path, f"acc_{dataset_name}_{run_name}.txt") + with open(results_path, "w") as f: + f.write(f"Accuracy: {acc}. ") + print(f"Accuracy: {acc}. ") + + # Save generations + generation_path = os.path.join(path, f"generation_{dataset_name}_{run_name}.jsonl") + dataset.to_json(generation_path, orient="records") + + +def format_answer(output_full, answer, answer_type="mcq"): + output = output_full + default = (output_full, answer) + if "\n##" in output: + try: + output = output.split("\n##")[1].split("\n")[0].strip().lower() + except Exception: + return default + if "###" in answer: + try: + answer = answer.split("answer is:")[1].split("###")[0].strip() + except Exception: + return default + + output = re.sub(r"[^a-zA-Z0-9]", " ", output).strip() + output = re.sub(" +", " ", output) + + if answer_type == "boolean": + output = clean_boolean_answer(output) + elif answer_type == "mcq": + output = clean_mcq_answer(output) + + if output in ["a", "b", "c", "d", "e", "yes", "no"]: + return output, answer + else: + return default + + +def clean_mcq_answer(output): + output = clean_answer(output) + try: + output = output[0] + except Exception: + return output + return output + + +def clean_boolean_answer(output): + if "yesyes" in output: + output = output.replace("yesyes", "yes") + elif "nono" in output: + output = output.replace("nono", "no") + elif "yesno" in output: + output = output.replace("yesno", "yes") + elif "noyes" in output: + output = output.replace("noyes", "no") + output = clean_answer(output) + return output + + +def clean_answer(output): + output_clean = output.encode("ascii", "ignore").decode("ascii") + return output_clean diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst index 522d124dfd9b..d6acad4afa03 100644 --- a/doc/source/contributor-how-to-build-docker-images.rst +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -26,7 +26,7 @@ Before we can start, we need to meet a few prerequisites in our local developmen default values, others must be specified when building the image. All available build arguments for each image are listed in one of the tables below. -Building the base image +Building the Base Image ----------------------- .. list-table:: @@ -65,6 +65,10 @@ Building the base image - The Flower package to be installed. - No - ``flwr`` or ``flwr-nightly`` + * - ``FLWR_VERSION_REF`` + - A `direct reference `_ without the ``@`` specifier. If both ``FLWR_VERSION`` and ``FLWR_VERSION_REF`` are specified, the ``FLWR_VERSION_REF`` has precedence. + - No + - `Direct Reference Examples`_ The following example creates a base Ubuntu/Alpine image with Python ``3.11.0``, pip :substitution-code:`|pip_version|`, setuptools :substitution-code:`|setuptools_version|` @@ -84,8 +88,8 @@ and Flower :substitution-code:`|stable_flwr_version|`: In this example, we specify our image name as ``flwr_base`` and the tag as ``0.1.0``. Remember that the build arguments as well as the name and tag can be adapted to your needs. These values serve as examples only. -Building the SuperLink/SuperNode or ServerApp image ---------------------------------------------------- +Building a Flower Binary Image +------------------------------ .. list-table:: :widths: 25 45 15 15 @@ -130,3 +134,21 @@ After creating the image, we can test whether the image is working: .. code-block:: bash $ docker run --rm flwr_superlink:0.1.0 --help + +Direct Reference Examples +------------------------- + +.. code-block:: bash + :substitutions: + + # main branch + git+https://github.com/adap/flower.git@main + + # commit hash + git+https://github.com/adap/flower.git@1187c707f1894924bfa693d99611cf6f93431835 + + # tag + git+https://github.com/adap/flower.git@|stable_flwr_version| + + # artifact store + https://artifact.flower.ai/py/main/latest/flwr-|stable_flwr_version|-py3-none-any.whl diff --git a/doc/source/tutorial-quickstart-mlx.rst b/doc/source/tutorial-quickstart-mlx.rst index 0999bf44d3b7..675a08502d26 100644 --- a/doc/source/tutorial-quickstart-mlx.rst +++ b/doc/source/tutorial-quickstart-mlx.rst @@ -109,7 +109,7 @@ You can also override the parameters defined in .. code:: shell # Override some arguments - $ flwr run . --run-config num-server-rounds=5,lr=0.05 + $ flwr run . --run-config "num-server-rounds=5 lr=0.05" What follows is an explanation of each component in the project you just created: dataset partition, the model, defining the ``ClientApp`` and diff --git a/doc/source/tutorial-quickstart-pytorch.rst b/doc/source/tutorial-quickstart-pytorch.rst index 4515e8d0eeb5..d00b9efbe16b 100644 --- a/doc/source/tutorial-quickstart-pytorch.rst +++ b/doc/source/tutorial-quickstart-pytorch.rst @@ -108,7 +108,7 @@ You can also override the parameters defined in the .. code:: shell # Override some arguments - $ flwr run . --run-config num-server-rounds=5,local-epochs=3 + $ flwr run . --run-config "num-server-rounds=5 local-epochs=3" What follows is an explanation of each component in the project you just created: dataset partition, the model, defining the ``ClientApp`` and diff --git a/e2e/test_superlink.sh b/e2e/test_superlink.sh index 684f386bd388..2016f6da1933 100755 --- a/e2e/test_superlink.sh +++ b/e2e/test_superlink.sh @@ -2,7 +2,7 @@ set -e case "$1" in - e2e-bare-https) + e2e-bare-https | e2e-bare-auth) ./generate.sh server_arg="--ssl-ca-certfile certificates/ca.crt --ssl-certfile certificates/server.pem --ssl-keyfile certificates/server.key" client_arg="--root-certificates certificates/ca.crt" @@ -37,14 +37,11 @@ case "$2" in client_auth_2="" ;; client-auth) - ./generate.sh rest_arg_superlink="" rest_arg_supernode="" server_address="127.0.0.1:9092" server_app_address="127.0.0.1:9091" db_arg="--database :flwr-in-memory-state:" - server_arg="--ssl-ca-certfile certificates/ca.crt --ssl-certfile certificates/server.pem --ssl-keyfile certificates/server.key" - client_arg="--root-certificates certificates/ca.crt" server_auth="--auth-list-public-keys keys/client_public_keys.csv --auth-superlink-private-key keys/server_credentials --auth-superlink-public-key keys/server_credentials.pub" client_auth_1="--auth-supernode-private-key keys/client_credentials_1 --auth-supernode-public-key keys/client_credentials_1.pub" client_auth_2="--auth-supernode-private-key keys/client_credentials_2 --auth-supernode-public-key keys/client_credentials_2.pub" diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md index 65c8a5b18fa8..61a6c80f3556 100644 --- a/examples/fl-dp-sa/README.md +++ b/examples/fl-dp-sa/README.md @@ -1,28 +1,63 @@ --- -tags: [basic, vision, fds] +tags: [DP, SecAgg, vision, fds] dataset: [MNIST] framework: [torch, torchvision] --- -# Example of Flower App with DP and SA +# Flower Example on MNIST with Differential Privacy and Secure Aggregation -This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation. -Note: This example is designed for a small number of rounds and is intended for demonstration purposes. +This example demonstrates a federated learning setup using the Flower, incorporating central differential privacy (DP) with client-side fixed clipping and secure aggregation (SA). It is intended for a small number of rounds for demonstration purposes. -## Install dependencies +This example is similar to the [quickstart-pytorch example](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) and extends it by integrating central differential privacy and secure aggregation. For more details on differential privacy and secure aggregation in Flower, please refer to the documentation [here](https://flower.ai/docs/framework/how-to-use-differential-privacy.html) and [here](https://flower.ai/docs/framework/contributor-ref-secure-aggregation-protocols.html). -```bash -# Using pip -pip install . +## Set up the project + +### Clone the project + +Start by cloning the example project: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/fl-dp-sa . && rm -rf flower && cd fl-dp-sa +``` + +This will create a new directory called `fl-dp-sa` containing the following files: -# Or using Poetry -poetry install +```shell +fl-dp-sa +├── fl_dp_sa +│ ├── client_app.py # Defines your ClientApp +│ ├── server_app.py # Defines your ServerApp +│ └── task.py # Defines your model, training, and data loading +├── pyproject.toml # Project metadata like dependencies and configs +└── README.md ``` -## Run +### Install dependencies and project -The example uses the MNIST dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`. +Install the dependencies defined in `pyproject.toml` as well as the `fl_dp_sa` package. ```shell -flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100 +# From a new python environment, run: +pip install -e . +``` + +## Run the project + +You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine. + +### Run with the Simulation Engine + +```bash +flwr run . +``` + +You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example: + +```bash +flwr run . --run-config "noise-multiplier=0.1 clipping-norm=5" ``` + +### Run with the Deployment Engine + +> \[!NOTE\] +> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker. diff --git a/examples/fl-dp-sa/fl_dp_sa/__init__.py b/examples/fl-dp-sa/fl_dp_sa/__init__.py index 741260348ab8..c5c9a7e9581c 100644 --- a/examples/fl-dp-sa/fl_dp_sa/__init__.py +++ b/examples/fl-dp-sa/fl_dp_sa/__init__.py @@ -1 +1 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" diff --git a/examples/fl-dp-sa/fl_dp_sa/client.py b/examples/fl-dp-sa/fl_dp_sa/client.py deleted file mode 100644 index b3b02c6e9d61..000000000000 --- a/examples/fl-dp-sa/fl_dp_sa/client.py +++ /dev/null @@ -1,42 +0,0 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" - -from flwr.client import ClientApp, NumPyClient -from flwr.client.mod import fixedclipping_mod, secaggplus_mod - -from fl_dp_sa.task import DEVICE, Net, get_weights, load_data, set_weights, test, train - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) - - -# Define FlowerClient and client_fn -class FlowerClient(NumPyClient): - def __init__(self, trainloader, testloader) -> None: - self.trainloader = trainloader - self.testloader = testloader - - def fit(self, parameters, config): - set_weights(net, parameters) - results = train(net, self.trainloader, self.testloader, epochs=1, device=DEVICE) - return get_weights(net), len(self.trainloader.dataset), results - - def evaluate(self, parameters, config): - set_weights(net, parameters) - loss, accuracy = test(net, self.testloader) - return loss, len(self.testloader.dataset), {"accuracy": accuracy} - - -def client_fn(cid: str): - """Create and return an instance of Flower `Client`.""" - trainloader, testloader = load_data(partition_id=int(cid)) - return FlowerClient(trainloader, testloader).to_client() - - -# Flower ClientApp -app = ClientApp( - client_fn=client_fn, - mods=[ - secaggplus_mod, - fixedclipping_mod, - ], -) diff --git a/examples/fl-dp-sa/fl_dp_sa/client_app.py b/examples/fl-dp-sa/fl_dp_sa/client_app.py new file mode 100644 index 000000000000..5630d4f4d14f --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/client_app.py @@ -0,0 +1,50 @@ +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" + +import torch +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context +from flwr.client.mod import fixedclipping_mod, secaggplus_mod + +from fl_dp_sa.task import Net, get_weights, load_data, set_weights, test, train + + +class FlowerClient(NumPyClient): + def __init__(self, trainloader, testloader) -> None: + self.net = Net() + self.trainloader = trainloader + self.testloader = testloader + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def fit(self, parameters, config): + set_weights(self.net, parameters) + results = train( + self.net, + self.trainloader, + self.testloader, + epochs=1, + device=self.device, + ) + return get_weights(self.net), len(self.trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.testloader, self.device) + return loss, len(self.testloader.dataset), {"accuracy": accuracy} + + +def client_fn(context: Context): + partition_id = context.node_config["partition-id"] + trainloader, testloader = load_data( + partition_id=partition_id, num_partitions=context.node_config["num-partitions"] + ) + return FlowerClient(trainloader, testloader).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + fixedclipping_mod, + ], +) diff --git a/examples/fl-dp-sa/fl_dp_sa/server.py b/examples/fl-dp-sa/fl_dp_sa/server_app.py similarity index 56% rename from examples/fl-dp-sa/fl_dp_sa/server.py rename to examples/fl-dp-sa/fl_dp_sa/server_app.py index 3ec0ba757b0d..1704b4942ff8 100644 --- a/examples/fl-dp-sa/fl_dp_sa/server.py +++ b/examples/fl-dp-sa/fl_dp_sa/server_app.py @@ -1,20 +1,22 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" from typing import List, Tuple from flwr.common import Context, Metrics, ndarrays_to_parameters -from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.server import ( + Driver, + LegacyContext, + ServerApp, + ServerConfig, +) from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping, FedAvg from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow from fl_dp_sa.task import Net, get_weights -# Define metric aggregation function def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] train_accuracies = [ num_examples * m["train_accuracy"] for num_examples, m in metrics @@ -22,7 +24,6 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - # Aggregate and return custom metric (weighted average) return { "train_loss": sum(train_losses) / sum(examples), "train_accuracy": sum(train_accuracies) / sum(examples), @@ -31,30 +32,36 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } -# Initialize model parameters -ndarrays = get_weights(Net()) -parameters = ndarrays_to_parameters(ndarrays) +app = ServerApp() -# Define strategy -strategy = FedAvg( - fraction_fit=0.2, - fraction_evaluate=0.0, # Disable evaluation for demo purpose - min_fit_clients=20, - min_available_clients=20, - fit_metrics_aggregation_fn=weighted_average, - initial_parameters=parameters, -) -strategy = DifferentialPrivacyClientSideFixedClipping( - strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20 -) +@app.main() +def main(driver: Driver, context: Context) -> None: + # Initialize global model + model_weights = get_weights(Net()) + parameters = ndarrays_to_parameters(model_weights) + + # Note: The fraction_fit value is configured based on the DP hyperparameter `num-sampled-clients`. + strategy = FedAvg( + fraction_fit=0.2, + fraction_evaluate=0.0, + min_fit_clients=20, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, + ) -app = ServerApp() + noise_multiplier = context.run_config["noise-multiplier"] + clipping_norm = context.run_config["clipping-norm"] + num_sampled_clients = context.run_config["num-sampled-clients"] + strategy = DifferentialPrivacyClientSideFixedClipping( + strategy, + noise_multiplier=noise_multiplier, + clipping_norm=clipping_norm, + num_sampled_clients=num_sampled_clients, + ) -@app.main() -def main(driver: Driver, context: Context) -> None: # Construct the LegacyContext context = LegacyContext( context=context, @@ -65,8 +72,8 @@ def main(driver: Driver, context: Context) -> None: # Create the train/evaluate workflow workflow = DefaultWorkflow( fit_workflow=SecAggPlusWorkflow( - num_shares=7, - reconstruction_threshold=4, + num_shares=context.run_config["num-shares"], + reconstruction_threshold=context.run_config["reconstruction-threshold"], ) ) diff --git a/examples/fl-dp-sa/fl_dp_sa/task.py b/examples/fl-dp-sa/fl_dp_sa/task.py index 5b4fd7dee592..c145cebe1378 100644 --- a/examples/fl-dp-sa/fl_dp_sa/task.py +++ b/examples/fl-dp-sa/fl_dp_sa/task.py @@ -1,24 +1,22 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" from collections import OrderedDict -from logging import INFO import torch import torch.nn as nn import torch.nn.functional as F -from flwr.common.logger import log from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize, ToTensor -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +fds = None # Cache FederatedDataset -class Net(nn.Module): - """Model.""" +class Net(nn.Module): def __init__(self) -> None: - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 6, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) @@ -36,9 +34,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc3(x) -def load_data(partition_id): +def load_data(partition_id: int, num_partitions: int): """Load partition MNIST data.""" - fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="ylecun/mnist", + partitioners={"train": partitioner}, + ) partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) @@ -70,8 +75,8 @@ def train(net, trainloader, valloader, epochs, device): loss.backward() optimizer.step() - train_loss, train_acc = test(net, trainloader) - val_loss, val_acc = test(net, valloader) + train_loss, train_acc = test(net, trainloader, device) + val_loss, val_acc = test(net, valloader, device) results = { "train_loss": train_loss, @@ -82,17 +87,17 @@ def train(net, trainloader, valloader, epochs, device): return results -def test(net, testloader): +def test(net, testloader, device): """Validate the model on the test set.""" - net.to(DEVICE) + net.to(device) criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): for batch in testloader: - images = batch["image"].to(DEVICE) - labels = batch["label"].to(DEVICE) - outputs = net(images.to(DEVICE)) - labels = labels.to(DEVICE) + images = batch["image"].to(device) + labels = batch["label"].to(device) + outputs = net(images.to(device)) + labels = labels.to(device) loss += criterion(outputs, labels).item() correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() accuracy = correct / len(testloader.dataset) diff --git a/examples/fl-dp-sa/flower.toml b/examples/fl-dp-sa/flower.toml deleted file mode 100644 index ea2e98206791..000000000000 --- a/examples/fl-dp-sa/flower.toml +++ /dev/null @@ -1,13 +0,0 @@ -[project] -name = "fl_dp_sa" -version = "1.0.0" -description = "" -license = "Apache-2.0" -authors = [ - "The Flower Authors ", -] -readme = "README.md" - -[flower.components] -serverapp = "fl_dp_sa.server:app" -clientapp = "fl_dp_sa.client:app" diff --git a/examples/fl-dp-sa/pyproject.toml b/examples/fl-dp-sa/pyproject.toml index 1ca343b072d9..fbb463cc1c05 100644 --- a/examples/fl-dp-sa/pyproject.toml +++ b/examples/fl-dp-sa/pyproject.toml @@ -1,21 +1,40 @@ [build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "fl-dp-sa" -version = "0.1.0" -description = "" +version = "1.0.0" +description = "Central Differential Privacy and Secure Aggregation in Flower" license = "Apache-2.0" -authors = [ - "The Flower Authors ", +dependencies = [ + "flwr[simulation]>=1.11.0", + "flwr-datasets[vision]>=0.3.0", + "torch==2.2.1", + "torchvision==0.17.1", ] -readme = "README.md" -[tool.poetry.dependencies] -python = "^3.9" -# Mandatory dependencies -flwr = { version = "^1.8.0", extras = ["simulation"] } -flwr-datasets = { version = "0.0.2", extras = ["vision"] } -torch = "2.2.1" -torchvision = "0.17.1" +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "flwrlabs" + +[tool.flwr.app.components] +serverapp = "fl_dp_sa.server_app:app" +clientapp = "fl_dp_sa.client_app:app" + +[tool.flwr.app.config] +# Parameters for the DP +noise-multiplier = 0.2 +clipping-norm = 10 +num-sampled-clients = 20 +# Parameters for the SecAgg+ protocol +num-shares = 7 +reconstruction-threshold = 4 + +[tool.flwr.federations] +default = "local-simulation" + +[tool.flwr.federations.local-simulation] +options.num-supernodes = 100 \ No newline at end of file diff --git a/examples/fl-dp-sa/requirements.txt b/examples/fl-dp-sa/requirements.txt deleted file mode 100644 index f20b9d71e339..000000000000 --- a/examples/fl-dp-sa/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr[simulation]>=1.8.0 -flwr-datasets[vision]==0.0.2 -torch==2.2.1 -torchvision==0.17.1 diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index fa3e9d0dc6fb..a7b047c090f0 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -4,7 +4,7 @@ dataset: [HIGGS] framework: [xgboost] --- -# Flower Example using XGBoost +# Federated Learning with XGBoost and Flower (Quickstart Example) This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset for this example to perform a binary classification task. @@ -12,72 +12,60 @@ Tree-based with bagging method is used for aggregation on the server. This project provides a minimal code example to enable you to get started quickly. For a more comprehensive code example, take a look at [xgboost-comprehensive](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive). -## Project Setup +## Set up the project -Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: +### Clone the project -```shell -git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-quickstart . && rm -rf flower && cd xgboost-quickstart -``` - -This will create a new directory called `xgboost-quickstart` containing the following files: - -``` --- README.md <- Your're reading this right now --- server.py <- Defines the server-side logic --- client.py <- Defines the client-side logic --- run.sh <- Commands to run experiments --- pyproject.toml <- Example dependencies -``` - -### Installing Dependencies - -Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`: +Start by cloning the example project: ```shell -# From a new python environment, run: -pip install . +git clone --depth=1 https://github.com/adap/flower.git _tmp \ + && mv _tmp/examples/xgboost-quickstart . \ + && rm -rf _tmp \ + && cd xgboost-quickstart ``` -Then, to verify that everything works correctly you can run the following command: +This will create a new directory called `xgboost-quickstart` with the following structure: ```shell -python3 -c "import flwr" +xgboost-quickstart +├── xgboost_quickstart +│ ├── __init__.py +│ ├── client_app.py # Defines your ClientApp +│ ├── server_app.py # Defines your ServerApp +│ └── task.py # Defines your utilities and data loading +├── pyproject.toml # Project metadata like dependencies and configs +└── README.md ``` -If you don't see any errors you're good to go! +### Install dependencies and project -## Run Federated Learning with XGBoost and Flower +Install the dependencies defined in `pyproject.toml` as well as the `xgboost_quickstart` package. -Afterwards you are ready to start the Flower server as well as the clients. -You can simply start the server in a terminal as follows: - -```shell -python3 server.py +```bash +pip install -e . ``` -Now you are ready to start the Flower clients which will participate in the learning. -To do so simply open two more terminal windows and run the following commands. +## Run the project -Start client 1 in the first terminal: +You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine. -```shell -python3 client.py --partition-id=0 +### Run with the Simulation Engine + +```bash +flwr run . ``` -Start client 2 in the second terminal: +You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example: -```shell -python3 client.py --partition-id=1 +```bash +flwr run . --run-config "num-server-rounds=5 params.eta=0.05" ``` -You will see that XGBoost is starting a federated training. - -Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows: +> \[!TIP\] +> For a more detailed walk-through check our [quickstart XGBoost tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) -```shell -poetry run ./run.sh -``` +### Run with the Deployment Engine -Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) -and [tutorial](https://flower.ai/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +> \[!NOTE\] +> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py deleted file mode 100644 index d505a7ede785..000000000000 --- a/examples/xgboost-quickstart/client.py +++ /dev/null @@ -1,207 +0,0 @@ -import argparse -import warnings -from logging import INFO -from typing import Union - -import flwr as fl -import xgboost as xgb -from datasets import Dataset, DatasetDict -from flwr.common import ( - Code, - EvaluateIns, - EvaluateRes, - FitIns, - FitRes, - GetParametersIns, - GetParametersRes, - Parameters, - Status, -) -from flwr.common.logger import log -from flwr_datasets import FederatedDataset -from flwr_datasets.partitioner import IidPartitioner - -warnings.filterwarnings("ignore", category=UserWarning) - -# Define arguments parser for the client/partition ID. -parser = argparse.ArgumentParser() -parser.add_argument( - "--partition-id", - default=0, - type=int, - help="Partition ID used for the current client.", -) -args = parser.parse_args() - - -# Define data partitioning related functions -def train_test_split(partition: Dataset, test_fraction: float, seed: int): - """Split the data into train and validation set given split rate.""" - train_test = partition.train_test_split(test_size=test_fraction, seed=seed) - partition_train = train_test["train"] - partition_test = train_test["test"] - - num_train = len(partition_train) - num_test = len(partition_test) - - return partition_train, partition_test, num_train, num_test - - -def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix: - """Transform dataset to DMatrix format for xgboost.""" - x = data["inputs"] - y = data["label"] - new_data = xgb.DMatrix(x, label=y) - return new_data - - -# Load (HIGGS) dataset and conduct partitioning -# We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process. -partitioner = IidPartitioner(num_partitions=30) -fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) - -# Load the partition for this `partition_id` -log(INFO, "Loading partition...") -partition = fds.load_partition(partition_id=args.partition_id, split="train") -partition.set_format("numpy") - -# Train/test splitting -train_data, valid_data, num_train, num_val = train_test_split( - partition, test_fraction=0.2, seed=42 -) - -# Reformat data to DMatrix for xgboost -log(INFO, "Reformatting data...") -train_dmatrix = transform_dataset_to_dmatrix(train_data) -valid_dmatrix = transform_dataset_to_dmatrix(valid_data) - -# Hyper-parameters for xgboost training -num_local_round = 1 -params = { - "objective": "binary:logistic", - "eta": 0.1, # Learning rate - "max_depth": 8, - "eval_metric": "auc", - "nthread": 16, - "num_parallel_tree": 1, - "subsample": 1, - "tree_method": "hist", -} - - -# Define Flower client -class XgbClient(fl.client.Client): - def __init__( - self, - train_dmatrix, - valid_dmatrix, - num_train, - num_val, - num_local_round, - params, - ): - self.train_dmatrix = train_dmatrix - self.valid_dmatrix = valid_dmatrix - self.num_train = num_train - self.num_val = num_val - self.num_local_round = num_local_round - self.params = params - - def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: - _ = (self, ins) - return GetParametersRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[]), - ) - - def _local_boost(self, bst_input): - # Update trees based on local training data. - for i in range(self.num_local_round): - bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) - - # Bagging: extract the last N=num_local_round trees for sever aggregation - bst = bst_input[ - bst_input.num_boosted_rounds() - - self.num_local_round : bst_input.num_boosted_rounds() - ] - - return bst - - def fit(self, ins: FitIns) -> FitRes: - global_round = int(ins.config["global_round"]) - if global_round == 1: - # First round local training - bst = xgb.train( - self.params, - self.train_dmatrix, - num_boost_round=self.num_local_round, - evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], - ) - else: - bst = xgb.Booster(params=self.params) - for item in ins.parameters.tensors: - global_model = bytearray(item) - - # Load global model into booster - bst.load_model(global_model) - - # Local training - bst = self._local_boost(bst) - - # Save model - local_model = bst.save_raw("json") - local_model_bytes = bytes(local_model) - - return FitRes( - status=Status( - code=Code.OK, - message="OK", - ), - parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=self.num_train, - metrics={}, - ) - - def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - # Load global model - bst = xgb.Booster(params=self.params) - for para in ins.parameters.tensors: - para_b = bytearray(para) - bst.load_model(para_b) - - # Run evaluation - eval_results = bst.eval_set( - evals=[(self.valid_dmatrix, "valid")], - iteration=bst.num_boosted_rounds() - 1, - ) - auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) - - global_round = ins.config["global_round"] - log(INFO, f"AUC = {auc} at round {global_round}") - - return EvaluateRes( - status=Status( - code=Code.OK, - message="OK", - ), - loss=0.0, - num_examples=self.num_val, - metrics={"AUC": auc}, - ) - - -# Start Flower client -fl.client.start_client( - server_address="127.0.0.1:8080", - client=XgbClient( - train_dmatrix, - valid_dmatrix, - num_train, - num_val, - num_local_round, - params, - ).to_client(), -) diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml index f1e451fe779a..da3561bfded4 100644 --- a/examples/xgboost-quickstart/pyproject.toml +++ b/examples/xgboost-quickstart/pyproject.toml @@ -3,17 +3,45 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "quickstart-xgboost" -version = "0.1.0" -description = "XGBoost Federated Learning Quickstart with Flower" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] +name = "xgboost_quickstart" +version = "1.0.0" +description = "Federated Learning with XGBoost and Flower (Quickstart Example)" +license = "Apache-2.0" dependencies = [ - "flwr>=1.8.0,<2.0", - "flwr-datasets>=0.1.0,<1.0.0", - "xgboost>=2.0.0,<3.0.0", + "flwr-nightly[simulation]==1.11.0.dev20240826", + "flwr-datasets>=0.3.0", + "xgboost>=2.0.0", ] [tool.hatch.build.targets.wheel] packages = ["."] + +[tool.flwr.app] +publisher = "flwrlabs" + +[tool.flwr.app.components] +serverapp = "xgboost_quickstart.server_app:app" +clientapp = "xgboost_quickstart.client_app:app" + +[tool.flwr.app.config] +# ServerApp +num-server-rounds = 3 +fraction-fit = 0.1 +fraction-evaluate = 0.1 + +# ClientApp +local-epochs = 1 +params.objective = "binary:logistic" +params.eta = 0.1 # Learning rate +params.max-depth = 8 +params.eval-metric = "auc" +params.nthread = 16 +params.num-parallel-tree = 1 +params.subsample = 1 +params.tree-method = "hist" + +[tool.flwr.federations] +default = "local-simulation" + +[tool.flwr.federations.local-simulation] +options.num-supernodes = 20 diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh deleted file mode 100755 index b35af58222ab..000000000000 --- a/examples/xgboost-quickstart/run.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ - -echo "Starting server" -python server.py & -sleep 5 # Sleep for 5s to give the server enough time to start - -for i in `seq 0 1`; do - echo "Starting client $i" - python3 client.py --partition-id=$i & -done - -# Enable CTRL+C to stop all background processes -trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM -# Wait for all background processes to complete -wait diff --git a/examples/xgboost-quickstart/server.py b/examples/xgboost-quickstart/server.py deleted file mode 100644 index 2246d32686a4..000000000000 --- a/examples/xgboost-quickstart/server.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Dict - -import flwr as fl -from flwr.server.strategy import FedXgbBagging - -# FL experimental settings -pool_size = 2 -num_rounds = 5 -num_clients_per_round = 2 -num_evaluate_clients = 2 - - -def evaluate_metrics_aggregation(eval_metrics): - """Return an aggregated metric (AUC) for evaluation.""" - total_num = sum([num for num, _ in eval_metrics]) - auc_aggregated = ( - sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num - ) - metrics_aggregated = {"AUC": auc_aggregated} - return metrics_aggregated - - -def config_func(rnd: int) -> Dict[str, str]: - """Return a configuration with global epochs.""" - config = { - "global_round": str(rnd), - } - return config - - -# Define strategy -strategy = FedXgbBagging( - fraction_fit=(float(num_clients_per_round) / pool_size), - min_fit_clients=num_clients_per_round, - min_available_clients=pool_size, - min_evaluate_clients=num_evaluate_clients, - fraction_evaluate=1.0, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, - on_evaluate_config_fn=config_func, - on_fit_config_fn=config_func, -) - -# Start Flower server -fl.server.start_server( - server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=num_rounds), - strategy=strategy, -) diff --git a/examples/xgboost-quickstart/xgboost_quickstart/__init__.py b/examples/xgboost-quickstart/xgboost_quickstart/__init__.py new file mode 100644 index 000000000000..470360b377a6 --- /dev/null +++ b/examples/xgboost-quickstart/xgboost_quickstart/__init__.py @@ -0,0 +1 @@ +"""xgboost_quickstart: A Flower / XGBoost app.""" diff --git a/examples/xgboost-quickstart/xgboost_quickstart/client_app.py b/examples/xgboost-quickstart/xgboost_quickstart/client_app.py new file mode 100644 index 000000000000..3aa199a10274 --- /dev/null +++ b/examples/xgboost-quickstart/xgboost_quickstart/client_app.py @@ -0,0 +1,139 @@ +"""xgboost_quickstart: A Flower / XGBoost app.""" + +import warnings + +from flwr.common.context import Context + +import xgboost as xgb +from flwr.client import Client, ClientApp +from flwr.common.config import unflatten_dict +from flwr.common import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + Parameters, + Status, +) + +from xgboost_quickstart.task import load_data, replace_keys + +warnings.filterwarnings("ignore", category=UserWarning) + + +# Define Flower Client and client_fn +class FlowerClient(Client): + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params + + def _local_boost(self, bst_input): + # Update trees based on local training data. + for i in range(self.num_local_round): + bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) + + # Bagging: extract the last N=num_local_round trees for sever aggregation + bst = bst_input[ + bst_input.num_boosted_rounds() + - self.num_local_round : bst_input.num_boosted_rounds() + ] + + return bst + + def fit(self, ins: FitIns) -> FitRes: + global_round = int(ins.config["global_round"]) + if global_round == 1: + # First round local training + bst = xgb.train( + self.params, + self.train_dmatrix, + num_boost_round=self.num_local_round, + evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], + ) + else: + bst = xgb.Booster(params=self.params) + global_model = bytearray(ins.parameters.tensors[0]) + + # Load global model into booster + bst.load_model(global_model) + + # Local training + bst = self._local_boost(bst) + + # Save model + local_model = bst.save_raw("json") + local_model_bytes = bytes(local_model) + + return FitRes( + status=Status( + code=Code.OK, + message="OK", + ), + parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), + num_examples=self.num_train, + metrics={}, + ) + + def evaluate(self, ins: EvaluateIns) -> EvaluateRes: + # Load global model + bst = xgb.Booster(params=self.params) + para_b = bytearray(ins.parameters.tensors[0]) + bst.load_model(para_b) + + # Run evaluation + eval_results = bst.eval_set( + evals=[(self.valid_dmatrix, "valid")], + iteration=bst.num_boosted_rounds() - 1, + ) + auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + + return EvaluateRes( + status=Status( + code=Code.OK, + message="OK", + ), + loss=0.0, + num_examples=self.num_val, + metrics={"AUC": auc}, + ) + + +def client_fn(context: Context): + # Load model and data + partition_id = context.node_config["partition-id"] + num_partitions = context.node_config["num-partitions"] + train_dmatrix, valid_dmatrix, num_train, num_val = load_data( + partition_id, num_partitions + ) + + cfg = replace_keys(unflatten_dict(context.run_config)) + num_local_round = cfg["local_epochs"] + + # Return Client instance + return FlowerClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + cfg["params"], + ) + + +# Flower ClientApp +app = ClientApp( + client_fn, +) diff --git a/examples/xgboost-quickstart/xgboost_quickstart/server_app.py b/examples/xgboost-quickstart/xgboost_quickstart/server_app.py new file mode 100644 index 000000000000..6b81c6caa785 --- /dev/null +++ b/examples/xgboost-quickstart/xgboost_quickstart/server_app.py @@ -0,0 +1,54 @@ +"""xgboost_quickstart: A Flower / XGBoost app.""" + +from typing import Dict + +from flwr.common import Context, Parameters +from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedXgbBagging + + +def evaluate_metrics_aggregation(eval_metrics): + """Return an aggregated metric (AUC) for evaluation.""" + total_num = sum([num for num, _ in eval_metrics]) + auc_aggregated = ( + sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num + ) + metrics_aggregated = {"AUC": auc_aggregated} + return metrics_aggregated + + +def config_func(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + +def server_fn(context: Context): + # Read from config + num_rounds = context.run_config["num-server-rounds"] + fraction_fit = context.run_config["fraction-fit"] + fraction_evaluate = context.run_config["fraction-evaluate"] + + # Init an empty Parameter + parameters = Parameters(tensor_type="", tensors=[]) + + # Define strategy + strategy = FedXgbBagging( + fraction_fit=fraction_fit, + fraction_evaluate=fraction_evaluate, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=config_func, + on_fit_config_fn=config_func, + initial_parameters=parameters, + ) + config = ServerConfig(num_rounds=num_rounds) + + return ServerAppComponents(strategy=strategy, config=config) + + +# Create ServerApp +app = ServerApp( + server_fn=server_fn, +) diff --git a/examples/xgboost-quickstart/xgboost_quickstart/task.py b/examples/xgboost-quickstart/xgboost_quickstart/task.py new file mode 100644 index 000000000000..09916d9ac04a --- /dev/null +++ b/examples/xgboost-quickstart/xgboost_quickstart/task.py @@ -0,0 +1,71 @@ +"""xgboost_quickstart: A Flower / XGBoost app.""" + +from logging import INFO + +import xgboost as xgb +from flwr.common import log +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + + +def train_test_split(partition, test_fraction, seed): + """Split the data into train and validation set given split rate.""" + train_test = partition.train_test_split(test_size=test_fraction, seed=seed) + partition_train = train_test["train"] + partition_test = train_test["test"] + + num_train = len(partition_train) + num_test = len(partition_test) + + return partition_train, partition_test, num_train, num_test + + +def transform_dataset_to_dmatrix(data): + """Transform dataset to DMatrix format for xgboost.""" + x = data["inputs"] + y = data["label"] + new_data = xgb.DMatrix(x, label=y) + return new_data + + +fds = None # Cache FederatedDataset + + +def load_data(partition_id, num_clients): + """Load partition HIGGS data.""" + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_clients) + fds = FederatedDataset( + dataset="jxie/higgs", + partitioners={"train": partitioner}, + ) + + # Load the partition for this `partition_id` + partition = fds.load_partition(partition_id, split="train") + partition.set_format("numpy") + + # Train/test splitting + train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=0.2, seed=42 + ) + + # Reformat data to DMatrix for xgboost + log(INFO, "Reformatting data...") + train_dmatrix = transform_dataset_to_dmatrix(train_data) + valid_dmatrix = transform_dataset_to_dmatrix(valid_data) + + return train_dmatrix, valid_dmatrix, num_train, num_val + + +def replace_keys(input_dict, match="-", target="_"): + """Recursively replace match string with target string in dictionary keys.""" + new_dict = {} + for key, value in input_dict.items(): + new_key = key.replace(match, target) + if isinstance(value, dict): + new_dict[new_key] = replace_keys(value, match, target) + else: + new_dict[new_key] = value + return new_dict diff --git a/src/docker/base/alpine/Dockerfile b/src/docker/base/alpine/Dockerfile index 3e6a246e53c1..ee1e11b2d070 100644 --- a/src/docker/base/alpine/Dockerfile +++ b/src/docker/base/alpine/Dockerfile @@ -33,6 +33,8 @@ RUN apk add --no-cache \ # require for compiling grpcio on ARM64 g++ \ libffi-dev \ + # required for installing flwr via git + git \ # create virtual env && python -m venv /python/venv @@ -42,12 +44,19 @@ ENV PATH=/python/venv/bin:$PATH # Install specific version of pip, setuptools and flwr ARG PIP_VERSION ARG SETUPTOOLS_VERSION -ARG FLWR_VERSION -ARG FLWR_PACKAGE=flwr RUN pip install -U --no-cache-dir \ pip==${PIP_VERSION} \ - setuptools==${SETUPTOOLS_VERSION} \ - ${FLWR_PACKAGE}==${FLWR_VERSION} + setuptools==${SETUPTOOLS_VERSION} + +ARG FLWR_VERSION +ARG FLWR_VERSION_REF +ARG FLWR_PACKAGE=flwr +# hadolint ignore=DL3013 +RUN if [ -z "${FLWR_VERSION_REF}" ]; then \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}==${FLWR_VERSION}; \ + else \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}@${FLWR_VERSION_REF}; \ + fi FROM python:${PYTHON_VERSION}-${DISTRO}${DISTRO_VERSION} AS base diff --git a/src/docker/base/ubuntu/Dockerfile b/src/docker/base/ubuntu/Dockerfile index ddc662a0ae98..47655b1a52a1 100644 --- a/src/docker/base/ubuntu/Dockerfile +++ b/src/docker/base/ubuntu/Dockerfile @@ -60,12 +60,19 @@ RUN pip install -U --no-cache-dir pip==${PIP_VERSION} setuptools==${SETUPTOOLS_V && python -m venv /python/venv ENV PATH=/python/venv/bin:$PATH -ARG FLWR_VERSION -ARG FLWR_PACKAGE=flwr RUN pip install -U --no-cache-dir \ pip==${PIP_VERSION} \ - setuptools==${SETUPTOOLS_VERSION} \ - ${FLWR_PACKAGE}==${FLWR_VERSION} + setuptools==${SETUPTOOLS_VERSION} + +ARG FLWR_VERSION +ARG FLWR_VERSION_REF +ARG FLWR_PACKAGE=flwr +# hadolint ignore=DL3013 +RUN if [ -z "${FLWR_VERSION_REF}" ]; then \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}==${FLWR_VERSION}; \ + else \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}@${FLWR_VERSION_REF}; \ + fi FROM $DISTRO:$DISTRO_VERSION AS base diff --git a/src/docker/distributed/.gitignore b/src/docker/distributed/.gitignore new file mode 100644 index 000000000000..1a11330c6e95 --- /dev/null +++ b/src/docker/distributed/.gitignore @@ -0,0 +1,3 @@ +superexec-certificates +superlink-certificates +server/state diff --git a/src/docker/distributed/certs.yml b/src/docker/distributed/certs.yml new file mode 100644 index 000000000000..48e157582e40 --- /dev/null +++ b/src/docker/distributed/certs.yml @@ -0,0 +1,6 @@ +services: + gen-certs: + build: + args: + SUPERLINK_IP: ${SUPERLINK_IP:-127.0.0.1} + SUPEREXEC_IP: ${SUPEREXEC_IP:-127.0.0.1} diff --git a/src/docker/distributed/client/compose.yml b/src/docker/distributed/client/compose.yml new file mode 100644 index 000000000000..ef69e40cc425 --- /dev/null +++ b/src/docker/distributed/client/compose.yml @@ -0,0 +1,128 @@ +services: + supernode-1: + image: flwr/supernode:${FLWR_VERSION:-1.11.0} + command: + - --superlink + - ${SUPERLINK_IP:-127.0.0.1}:9092 + - --supernode-address + - 0.0.0.0:9094 + - --isolation + - process + - --node-config + - "partition-id=0 num-partitions=2" + - --root-certificates + - certificates/ca.crt + secrets: + - source: superlink-ca-certfile + target: /app/certificates/ca.crt + + supernode-2: + image: flwr/supernode:${FLWR_VERSION:-1.11.0} + command: + - --superlink + - ${SUPERLINK_IP:-127.0.0.1}:9092 + - --supernode-address + - 0.0.0.0:9095 + - --isolation + - process + - --node-config + - "partition-id=1 num-partitions=2" + - --root-certificates + - certificates/ca.crt + secrets: + - source: superlink-ca-certfile + target: /app/certificates/ca.crt + + # uncomment to add another SuperNode + # + # supernode-3: + # image: flwr/supernode:${FLWR_VERSION:-1.11.0} + # command: + # - --superlink + # - ${SUPERLINK_IP:-127.0.0.1}:9092 + # - --supernode-address + # - 0.0.0.0:9096 + # - --isolation + # - process + # - --node-config + # - "partition-id=1 num-partitions=2" + # - --root-certificates + # - certificates/ca.crt + # secrets: + # - source: superlink-ca-certfile + # target: /app/certificates/ca.crt + + clientapp-1: + build: + context: ${PROJECT_DIR:-.} + dockerfile_inline: | + FROM flwr/clientapp:${FLWR_VERSION:-1.11.0} + + WORKDIR /app + COPY --chown=app:app pyproject.toml . + RUN sed -i 's/.*flwr\[simulation\].*//' pyproject.toml \ + && python -m pip install -U --no-cache-dir . + + ENTRYPOINT ["flwr-clientapp"] + command: + - --supernode + - supernode-1:9094 + deploy: + resources: + limits: + cpus: "2" + stop_signal: SIGINT + depends_on: + - supernode-1 + + clientapp-2: + build: + context: ${PROJECT_DIR:-.} + dockerfile_inline: | + FROM flwr/clientapp:${FLWR_VERSION:-1.11.0} + + WORKDIR /app + COPY --chown=app:app pyproject.toml . + RUN sed -i 's/.*flwr\[simulation\].*//' pyproject.toml \ + && python -m pip install -U --no-cache-dir . + + ENTRYPOINT ["flwr-clientapp"] + command: + - --supernode + - supernode-2:9095 + deploy: + resources: + limits: + cpus: "2" + stop_signal: SIGINT + depends_on: + - supernode-2 + + # uncomment to add another ClientApp + # + # clientapp-3: + # build: + # context: ${PROJECT_DIR:-.} + # dockerfile_inline: | + # FROM flwr/clientapp:${FLWR_VERSION:-1.11.0} + + # WORKDIR /app + # COPY --chown=app:app pyproject.toml . + # RUN sed -i 's/.*flwr\[simulation\].*//' pyproject.toml \ + # && python -m pip install -U --no-cache-dir . + + # ENTRYPOINT ["flwr-clientapp"] + # command: + # - --supernode + # - supernode-3:9096 + # deploy: + # resources: + # limits: + # cpus: "2" + # stop_signal: SIGINT + # depends_on: + # - supernode-3 + +secrets: + superlink-ca-certfile: + file: ../superlink-certificates/ca.crt diff --git a/src/docker/distributed/server/compose.yml b/src/docker/distributed/server/compose.yml new file mode 100644 index 000000000000..fc6dd6f58717 --- /dev/null +++ b/src/docker/distributed/server/compose.yml @@ -0,0 +1,67 @@ +services: + superlink: + image: flwr/superlink:${FLWR_VERSION:-1.11.0} + command: + - --ssl-ca-certfile=certificates/ca.crt + - --ssl-certfile=certificates/server.pem + - --ssl-keyfile=certificates/server.key + - --database=state/state.db + volumes: + - ./state/:/app/state/:rw + secrets: + - source: superlink-ca-certfile + target: /app/certificates/ca.crt + - source: superlink-certfile + target: /app/certificates/server.pem + - source: superlink-keyfile + target: /app/certificates/server.key + ports: + - 9092:9092 + + superexec: + build: + context: ${PROJECT_DIR:-.} + dockerfile_inline: | + FROM flwr/superexec:${FLWR_VERSION:-1.11.0} + + WORKDIR /app + COPY --chown=app:app pyproject.toml . + RUN sed -i 's/.*flwr\[simulation\].*//' pyproject.toml \ + && python -m pip install -U --no-cache-dir . + + ENTRYPOINT ["flower-superexec"] + command: + - --executor + - flwr.superexec.deployment:executor + - --executor-config + - superlink="superlink:9091" root-certificates="certificates/superlink-ca.crt" + - --ssl-ca-certfile=certificates/ca.crt + - --ssl-certfile=certificates/server.pem + - --ssl-keyfile=certificates/server.key + secrets: + - source: superlink-ca-certfile + target: /app/certificates/superlink-ca.crt + - source: superexec-ca-certfile + target: /app/certificates/ca.crt + - source: superexec-certfile + target: /app/certificates/server.pem + - source: superexec-keyfile + target: /app/certificates/server.key + ports: + - 9093:9093 + depends_on: + - superlink + +secrets: + superlink-ca-certfile: + file: ../superlink-certificates/ca.crt + superlink-certfile: + file: ../superlink-certificates/server.pem + superlink-keyfile: + file: ../superlink-certificates/server.key + superexec-ca-certfile: + file: ../superexec-certificates/ca.crt + superexec-certfile: + file: ../superexec-certificates/server.pem + superexec-keyfile: + file: ../superexec-certificates/server.key diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 2a913b3a248d..5e76acd1ddd8 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None: def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt: client_fn_args = inspect.signature(client_fn).parameters - first_arg = list(client_fn_args.keys())[0] if len(client_fn_args) != 1: _alert_erroneous_client_fn() + first_arg = list(client_fn_args.keys())[0] first_arg_type = client_fn_args[first_arg].annotation if first_arg_type is str or first_arg == "cid":