Merge branch 'main' into patch-2
danieljanes authored Jan 17, 2024
2 parents 02a6e2c + 0daa3d7 commit a7b8f23
Showing 64 changed files with 690 additions and 461 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/_docker-build.yml
@@ -1,4 +1,4 @@
-name: Reusable docker server image build workflow
+name: Reusable docker image build workflow

on:
workflow_call:
@@ -35,7 +35,7 @@ permissions:
# based on https://docs.docker.com/build/ci/github-actions/multi-platform/#distribute-build-across-multiple-runners
jobs:
build:
-name: Build server image
+name: Build image
runs-on: ubuntu-22.04
timeout-minutes: 60
outputs:
@@ -98,7 +98,7 @@ jobs:
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
-uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0
+uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0
with:
name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }}
path: /tmp/digests/*
@@ -114,7 +114,7 @@ jobs:
metadata: ${{ steps.meta.outputs.json }}
steps:
- name: Download digests
-uses: actions/download-artifact@f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110 # v4.1.0
+uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1
with:
pattern: digests-${{ needs.build.outputs.build-id }}-*
path: /tmp/digests
2 changes: 1 addition & 1 deletion .github/workflows/docker-base.yml
@@ -39,7 +39,7 @@ jobs:
echo "ubuntu-version=${{ env.DEFAULT_UBUNTU }}" >> "$GITHUB_OUTPUT"
build-base-images:
-name: Build images
+name: Build base images
uses: ./.github/workflows/_docker-build.yml
needs: parameters
strategy:
36 changes: 36 additions & 0 deletions .github/workflows/docker-client.yml
@@ -0,0 +1,36 @@
name: Build docker client image

on:
workflow_dispatch:
inputs:
flwr-version:
description: "Version of Flower e.g. (1.6.0)."
required: true
type: string

permissions:
contents: read

jobs:
build-client-images:
name: Build client images
uses: ./.github/workflows/_docker-build.yml
# run only on default branch when using it with workflow_dispatch
if: github.ref_name == github.event.repository.default_branch
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
with:
namespace-repository: flwr/client
file-dir: src/docker/client
build-args: |
FLWR_VERSION=${{ github.event.inputs.flwr-version }}
BASE_IMAGE_TAG=py${{ matrix.python-version }}-ubuntu22.04
tags: |
${{ github.event.inputs.flwr-version }}-py${{ matrix.python-version }}-ubuntu22.04
${{ github.event.inputs.flwr-version }}
latest
secrets:
dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }}
dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }}
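The new workflow above is manual-only (`workflow_dispatch` with a required `flwr-version` input). Besides the Actions UI, such a workflow can be triggered via GitHub's standard workflow-dispatch REST endpoint; a minimal sketch (`build_dispatch_request` is a hypothetical helper, and the owner, repo, and token values are placeholders, not taken from this commit):

```python
import json
import urllib.request


def build_dispatch_request(owner, repo, workflow_file, ref, inputs, token):
    """Build (but do not send) a workflow_dispatch POST request."""
    url = (
        f"https://api.github.com/repos/{owner}/{repo}"
        f"/actions/workflows/{workflow_file}/dispatches"
    )
    body = json.dumps({"ref": ref, "inputs": inputs}).encode()
    req = urllib.request.Request(url, data=body, method="POST")
    req.add_header("Accept", "application/vnd.github+json")
    req.add_header("Authorization", f"Bearer {token}")
    return req  # send with urllib.request.urlopen(req)


# Example: trigger the client image build for Flower 1.6.0
req = build_dispatch_request(
    "OWNER", "REPO", "docker-client.yml", "main",
    {"flwr-version": "1.6.0"}, "GITHUB_TOKEN",
)
```

Note that the `if: github.ref_name == github.event.repository.default_branch` guard means the dispatch only builds when run against the default branch.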
2 changes: 1 addition & 1 deletion baselines/hfedxgboost/README.md
@@ -11,7 +11,7 @@ dataset: [a9a, cod-rna, ijcnn1, space_ga, cpusmall, YearPredictionMSD]
**Paper:** [arxiv.org/abs/2304.07537](https://arxiv.org/abs/2304.07537)

-**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Laneearly_stop_patience_rounds: 100
+**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Lane

**Abstract:** The privacy-sensitive nature of decentralized datasets and the robustness of eXtreme Gradient Boosting (XGBoost) on tabular data raise the need to train XGBoost in the context of federated learning (FL). Existing works on federated XGBoost in the horizontal setting rely on the sharing of gradients, which induces per-node level communication frequency and serious privacy concerns. To alleviate these problems, we develop an innovative framework for horizontal federated XGBoost which does not depend on the sharing of gradients and simultaneously boosts privacy and communication efficiency by making the learning rates of the aggregated tree ensembles learnable. We conduct extensive evaluations on various classification and regression datasets, showing our approach achieves performance comparable to the state-of-the-art method and effectively improves communication efficiency by lowering both communication rounds and communication overhead by factors ranging from 25x to 700x.

2 changes: 1 addition & 1 deletion datasets/e2e/tensorflow/pyproject.toml
@@ -9,7 +9,7 @@ description = "Flower Datasets with TensorFlow"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
-python = "^3.8"
+python = ">=3.8,<3.11"
flwr-datasets = { path = "./../../", extras = ["vision"] }
tensorflow-cpu = "^2.9.1, !=2.11.1"
parameterized = "==0.9.0"
22 changes: 22 additions & 0 deletions doc/source/how-to-install-flower.rst
@@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 <https://docs.python.org/3.8/>`_, but `Python 3.10
Install stable release
----------------------

Using pip
~~~~~~~~~

Stable releases are available on `PyPI <https://pypi.org/project/flwr/>`_::

python -m pip install flwr
@@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed
python -m pip install flwr[simulation]


Using conda (or mamba)
~~~~~~~~~~~~~~~~~~~~~~

Flower can also be installed from the ``conda-forge`` channel.

If you have not added ``conda-forge`` to your channels, you will first need to run the following::

conda config --add channels conda-forge
conda config --set channel_priority strict

Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``::

conda install flwr

or with ``mamba``::

mamba install flwr
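Whichever installer is used, the presence and version of ``flwr`` can also be checked programmatically with only the standard library (a sketch, not part of the documented install flow; it returns ``None`` instead of raising when the package is missing):

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print(installed_version("flwr"))
```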


Verify installation
-------------------

@@ -484,7 +484,7 @@
" min_available_clients=10, # Wait until all 10 clients are available\n",
")\n",
"\n",
-"# Specify the resources each of your clients need. By default, each \n",
+"# Specify the resources each of your clients need. By default, each\n",
"# client will be allocated 1x CPU and 0x GPUs\n",
"client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
"if DEVICE.type == \"cuda\":\n",
2 changes: 1 addition & 1 deletion examples/android/README.md
@@ -54,4 +54,4 @@ poetry run ./run.sh

Download and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`.

-When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training.
+When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively.
@@ -24,7 +24,7 @@ def main(cfg: DictConfig):
save_path = HydraConfig.get().runtime.output_dir

## 2. Prepare your dataset
-# When simulating FL workloads we have a lot of freedom on how the FL clients behave,
+# When simulating FL runs we have a lot of freedom on how the FL clients behave,
# what data they have, how much data, etc. This is not possible in real FL settings.
# In simulation you'd often encounter two types of dataset:
# * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST,
@@ -91,7 +91,7 @@ def main(cfg: DictConfig):
"num_gpus": 0.0,
}, # (optional) controls the degree of parallelism of your simulation.
# Lower resources per client allow for more clients to run concurrently
-# (but need to be set taking into account the compute/memory footprint of your workload)
+# (but need to be set taking into account the compute/memory footprint of your run)
# `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated
# `num_gpus` is a ratio indicating the portion of gpu memory that a client needs.
)
12 changes: 6 additions & 6 deletions examples/mt-pytorch/driver.py
@@ -54,13 +54,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
-create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
-req=driver_pb2.CreateWorkloadRequest()
+create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
+req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

-workload_id = create_workload_res.workload_id
-print(f"Created workload id {workload_id}")
+run_id = create_run_res.run_id
+print(f"Created run id {run_id}")

history = History()
for server_round in range(num_rounds):
@@ -93,7 +93,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node ID's from the server
-get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)
+get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
@@ -125,7 +125,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
new_task_ins = task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
-workload_id=workload_id,
+run_id=run_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
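The hunk headers above reference a `weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics` aggregation function. A typical implementation of that signature (a sketch of the usual example-count-weighted mean, with float-valued metrics for simplicity; not the exact code from this example) looks like this:

```python
from typing import Dict, List, Tuple

# Simplified metrics type: Flower's Metrics maps str to a broader Scalar type
Metrics = Dict[str, float]


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client metrics, weighting each client by its example count."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    keys = metrics[0][1].keys()
    return {
        key: sum(num * m[key] for num, m in metrics) / total_examples
        for key in keys
    }
```

A client with more training examples thus contributes proportionally more to the aggregated value, mirroring how FedAvg weights model updates.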
14 changes: 7 additions & 7 deletions examples/pytorch-from-centralized-to-federated/cifar.py
@@ -73,10 +73,10 @@ def apply_transforms(batch):


def train(
-net: Net,
-trainloader: torch.utils.data.DataLoader,
-epochs: int,
-device: torch.device, # pylint: disable=no-member
+net: Net,
+trainloader: torch.utils.data.DataLoader,
+epochs: int,
+device: torch.device, # pylint: disable=no-member
) -> None:
"""Train the network."""
# Define loss and optimizer
@@ -110,9 +110,9 @@ def train(


def test(
-net: Net,
-testloader: torch.utils.data.DataLoader,
-device: torch.device, # pylint: disable=no-member
+net: Net,
+testloader: torch.utils.data.DataLoader,
+device: torch.device, # pylint: disable=no-member
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
# Define loss and metrics
12 changes: 6 additions & 6 deletions examples/pytorch-from-centralized-to-federated/client.py
@@ -24,10 +24,10 @@ class CifarClient(fl.client.NumPyClient):
"""Flower client implementing CIFAR-10 image classification using PyTorch."""

def __init__(
-self,
-model: cifar.Net,
-trainloader: DataLoader,
-testloader: DataLoader,
+self,
+model: cifar.Net,
+trainloader: DataLoader,
+testloader: DataLoader,
) -> None:
self.model = model
self.trainloader = trainloader
@@ -61,15 +61,15 @@ def set_parameters(self, parameters: List[np.ndarray]) -> None:
self.model.load_state_dict(state_dict, strict=True)

def fit(
-self, parameters: List[np.ndarray], config: Dict[str, str]
+self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
return self.get_parameters(config={}), len(self.trainloader.dataset), {}

def evaluate(
-self, parameters: List[np.ndarray], config: Dict[str, str]
+self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
2 changes: 1 addition & 1 deletion examples/quickstart-pytorch-lightning/client.py
@@ -10,6 +10,7 @@

disable_progress_bar()


class FlowerClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
@@ -55,7 +56,6 @@ def _set_parameters(model, parameters):


def main() -> None:

parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
16 changes: 10 additions & 6 deletions examples/quickstart-pytorch-lightning/mnist.py
@@ -86,16 +86,20 @@ def load_data(partition):
# 60 % for the federated train and 20 % for the federated validation (both in fit)
partition_train_valid = partition_full["train"].train_test_split(train_size=0.75)
trainloader = DataLoader(
-partition_train_valid["train"], batch_size=32,
-shuffle=True, collate_fn=collate_fn, num_workers=1
+partition_train_valid["train"],
+batch_size=32,
+shuffle=True,
+collate_fn=collate_fn,
+num_workers=1,
)
valloader = DataLoader(
-partition_train_valid["test"], batch_size=32,
-collate_fn=collate_fn, num_workers=1
+partition_train_valid["test"],
+batch_size=32,
+collate_fn=collate_fn,
+num_workers=1,
)
testloader = DataLoader(
-partition_full["test"], batch_size=32,
-collate_fn=collate_fn, num_workers=1
+partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1
)
return trainloader, valloader, testloader

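The comment in this file notes that a 75/25 re-split of the 80 % train partition yields 60 % federated train and 20 % federated validation. A small sketch of that arithmetic (the helper name is hypothetical; only the fractions come from the file):

```python
def split_sizes(n: int, test_frac: float = 0.2, train_frac: float = 0.75):
    """Return (train, val, test) example counts for the split described above."""
    n_test = int(n * test_frac)              # 20 % held out for testing
    n_train_val = n - n_test                 # remaining 80 % used in fit()
    n_train = int(n_train_val * train_frac)  # 75 % of 80 % = 60 % of the total
    n_val = n_train_val - n_train            # 25 % of 80 % = 20 % of the total
    return n_train, n_val, n_test


print(split_sizes(60000))  # → (36000, 12000, 12000)
```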
4 changes: 3 additions & 1 deletion examples/quickstart-sklearn-tabular/client.py
@@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"test_accuracy": accuracy}

# Start Flower client
-fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client())
+fl.client.start_client(
+server_address="0.0.0.0:8080", client=IrisClient().to_client()
+)
13 changes: 7 additions & 6 deletions examples/secaggplus-mt/driver.py
@@ -23,7 +23,8 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
-workload_id=workload_id,
+run_id=run_id,
+run_id=run_id,
task=merge(
task,
task_pb2.Task(
@@ -84,13 +85,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
-create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
-req=driver_pb2.CreateWorkloadRequest()
+create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
+req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

-workload_id = create_workload_res.workload_id
-print(f"Created workload id {workload_id}")
+run_id = create_run_res.run_id
+print(f"Created run id {run_id}")

history = History()
for server_round in range(num_rounds):
@@ -119,7 +120,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node ID's from the server
-get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)
+get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -82,7 +82,7 @@ rest = ["requests", "starlette", "uvicorn"]
types-dataclasses = "==0.6.6"
types-protobuf = "==3.19.18"
types-requests = "==2.31.0.10"
-types-setuptools = "==68.2.0.0"
+types-setuptools = "==69.0.0.20240115"
clang-format = "==17.0.4"
isort = "==5.12.0"
black = { version = "==23.10.1", extras = ["jupyter"] }
8 changes: 8 additions & 0 deletions src/docker/client/Dockerfile
@@ -0,0 +1,8 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.

ARG BASE_REPOSITORY=flwr/base
ARG BASE_IMAGE_TAG
FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG

ARG FLWR_VERSION
RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION}
12 changes: 6 additions & 6 deletions src/proto/flwr/proto/driver.proto
@@ -21,8 +21,8 @@ import "flwr/proto/node.proto";
import "flwr/proto/task.proto";

service Driver {
-// Request workload_id
-rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {}
+// Request run_id
+rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Return a set of nodes
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}
@@ -34,12 +34,12 @@ service Driver {
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

-// CreateWorkload
-message CreateWorkloadRequest {}
-message CreateWorkloadResponse { sint64 workload_id = 1; }
+// CreateRun
+message CreateRunRequest {}
+message CreateRunResponse { sint64 run_id = 1; }

// GetNodes messages
-message GetNodesRequest { sint64 workload_id = 1; }
+message GetNodesRequest { sint64 run_id = 1; }
message GetNodesResponse { repeated Node nodes = 1; }

// PushTaskIns messages
