Skip to content

Commit

Permalink
Merge branch 'main' into fds-rename-size-partitioner
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Sep 11, 2024
2 parents 648fd7b + 570eb0f commit 87323fa
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 109 deletions.
2 changes: 1 addition & 1 deletion dev/changelog_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

type = ["ci", "docs", "feat", "fix", "refactor", "break"]

project = ["framework", "baselines", "datasets", "examples", "benchmarks"]
project = ["framework", "baselines", "datasets", "examples", "benchmarks", "glossary"]

scope = "skip"

Expand Down
10 changes: 5 additions & 5 deletions doc/source/how-to-authenticate-supernodes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ Use the following terminal command to start a Flower :code:`SuperNode` that has
.. code-block:: bash
flower-superlink
--ssl-ca-certfile certificates/ca.crt
--ssl-certfile certificates/server.pem
--ssl-ca-certfile certificates/ca.crt
--ssl-certfile certificates/server.pem
--ssl-keyfile certificates/server.key
--auth-list-public-keys keys/client_public_keys.csv
--auth-superlink-private-key keys/server_credentials
--auth-superlink-public-key keys/server_credentials.pub
Let's break down the authentication flags:

1. The first flag :code:`--auth-list-public-keys` expects a path to a CSV file storing all known node public keys. You need to store all known node public keys that are allowed to participate in a federation in one CSV file (:code:`.csv`).
Expand All @@ -56,8 +56,8 @@ Similar to the long-running Flower server (:code:`SuperLink`), you can easily en
Use the following terminal command to start an authenticated :code:`SuperNode`:

.. code-block:: bash
flower-client-app client:app
flower-supernode
--root-certificates certificates/ca.crt
--superlink 127.0.0.1:9092
--auth-supernode-private-key keys/client_credentials
Expand Down
10 changes: 5 additions & 5 deletions doc/source/how-to-enable-ssl-connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Enable SSL connections
This guide describes how to a SSL-enabled secure Flower server (:code:`SuperLink`) can be started and
how a Flower client (:code:`SuperNode`) can establish a secure connections to it.

A complete code example demonstrating a secure connection can be found
A complete code example demonstrating a secure connection can be found
`here <https://github.com/adap/flower/tree/main/examples/advanced-tensorflow>`_.

The code example comes with a :code:`README.md` file which explains how to start it. Although it is
Expand Down Expand Up @@ -42,9 +42,9 @@ Use the following terminal command to start a sever (SuperLink) that uses the pr

.. code-block:: bash
flower-superlink
--ssl-ca-certfile certificates/ca.crt
--ssl-certfile certificates/server.pem
flower-superlink
--ssl-ca-certfile certificates/ca.crt
--ssl-certfile certificates/server.pem
--ssl-keyfile certificates/server.key
When providing certificates, the server expects a tuple of three certificates paths: CA certificate, server certificate and server private key.
Expand All @@ -57,7 +57,7 @@ Use the following terminal command to start a client (SuperNode) that uses the p

.. code-block:: bash
flower-client-app client:app
flower-supernode
--root-certificates certificates/ca.crt
--superlink 127.0.0.1:9092
Expand Down
15 changes: 15 additions & 0 deletions src/proto/flwr/proto/clientappio.proto
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;
Expand Down
15 changes: 15 additions & 0 deletions src/proto/flwr/proto/fab.proto
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;
Expand Down
59 changes: 23 additions & 36 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,36 +136,23 @@ def new(
username = prompt_text("Please provide your Flower username")

if framework is not None:
framework_str_upper = str(framework.value)
framework_str = str(framework.value)
else:
framework_value = prompt_options(
framework_str = prompt_options(
"Please select ML framework by typing in the number",
[mlf.value for mlf in MlFramework],
)
selected_value = [
name
for name, value in vars(MlFramework).items()
if value == framework_value
]
framework_str_upper = selected_value[0]

framework_str = framework_str_upper.lower()

llm_challenge_str = None
if framework_str == "flowertune":
if framework_str == MlFramework.FLOWERTUNE:
llm_challenge_value = prompt_options(
"Please select LLM challenge by typing in the number",
sorted([challenge.value for challenge in LlmChallengeName]),
)
selected_value = [
name
for name, value in vars(LlmChallengeName).items()
if value == llm_challenge_value
]
llm_challenge_str = selected_value[0]
llm_challenge_str = llm_challenge_str.lower()
llm_challenge_str = llm_challenge_value.lower()

is_baseline_project = framework_str == "baseline"
if framework_str == MlFramework.BASELINE:
framework_str = "baseline"

print(
typer.style(
Expand All @@ -176,19 +163,21 @@ def new(
)

context = {
"framework_str": framework_str_upper,
"framework_str": framework_str,
"import_name": import_name.replace("-", "_"),
"package_name": package_name,
"project_name": app_name,
"username": username,
}

template_name = framework_str.lower()

# List of files to render
if llm_challenge_str:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
"README.md": {"template": f"app/README.{template_name}.md.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server_app.py": {
"template": "app/code/flwr_tune/server_app.py.tpl"
Expand Down Expand Up @@ -235,44 +224,42 @@ def new(
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server_app.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
"template": f"app/code/server.{template_name}.py.tpl"
},
f"{import_name}/client_app.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
"template": f"app/code/client.{template_name}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
MlFramework.PYTORCH.value,
MlFramework.JAX.value,
MlFramework.HUGGINGFACE.value,
MlFramework.MLX.value,
MlFramework.TENSORFLOW.value,
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
"template": f"app/code/task.{template_name}.py.tpl"
}

if is_baseline_project:
if framework_str == "baseline":
# Include additional files for baseline template
for file_name in ["model", "dataset", "strategy", "utils", "__init__"]:
files[f"{import_name}/{file_name}.py"] = {
"template": f"app/code/{file_name}.{framework_str}.py.tpl"
"template": f"app/code/{file_name}.{template_name}.py.tpl"
}

# Replace README.md
files["README.md"]["template"] = f"app/README.{framework_str}.md.tpl"
files["README.md"]["template"] = f"app/README.{template_name}.md.tpl"

# Add LICENSE
files["LICENSE"] = {"template": "app/LICENSE.tpl"}

context["framework_str"] = "baseline"

for file_path, value in files.items():
render_and_create(
file_path=project_dir / file_path,
Expand Down
48 changes: 19 additions & 29 deletions src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
"""$project_name: A Flower / $framework_str app."""

import torch
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from transformers import AutoModelForSequenceClassification

from $import_name.task import (
get_weights,
load_data,
set_weights,
train,
test,
CHECKPOINT,
DEVICE,
)
from $import_name.task import get_weights, load_data, set_weights, test, train


# Flower client
Expand All @@ -22,37 +15,34 @@ class FlowerClient(NumPyClient):
self.trainloader = trainloader
self.testloader = testloader
self.local_epochs = local_epochs

def get_parameters(self, config):
return get_weights(self.net)

def set_parameters(self, parameters):
set_weights(self.net, parameters)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(
self.net,
self.trainloader,
epochs=self.local_epochs,
)
return self.get_parameters(config={}), len(self.trainloader), {}
set_weights(self.net, parameters)
train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
return get_weights(self.net), len(self.trainloader), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(self.net, self.testloader)
set_weights(self.net, parameters)
loss, accuracy = test(self.net, self.testloader, self.device)
return float(loss), len(self.testloader), {"accuracy": accuracy}


def client_fn(context: Context):
# Load model and data
net = AutoModelForSequenceClassification.from_pretrained(
CHECKPOINT, num_labels=2
).to(DEVICE)

# Get this client's dataset partition
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
trainloader, valloader = load_data(partition_id, num_partitions)
model_name = context.run_config["model-name"]
trainloader, valloader = load_data(partition_id, num_partitions, model_name)

# Load model
num_labels = context.run_config["num-labels"]
net = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels
)

local_epochs = context.run_config["local-epochs"]

# Return Client instance
Expand Down
21 changes: 18 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.server.strategy import FedAvg
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from transformers import AutoModelForSequenceClassification

from $import_name.task import get_weights


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]

# Initialize global model
model_name = context.run_config["model-name"]
num_labels = context.run_config["num-labels"]
net = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels
)

weights = get_weights(net)
initial_parameters = ndarrays_to_parameters(weights)

# Define strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
initial_parameters=initial_parameters,
)
config = ServerConfig(num_rounds=num_rounds)

Expand Down
Loading

0 comments on commit 87323fa

Please sign in to comment.