-
Notifications
You must be signed in to change notification settings - Fork 941
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fds-bump-up-datasets-version
- Loading branch information
Showing
12 changed files
with
315 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
keys/ | ||
certificates/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""authexample.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""authexample: An authenticated Flower / PyTorch app.""" | ||
|
||
import torch | ||
from flwr.client import ClientApp, NumPyClient | ||
from flwr.common import Context | ||
|
||
from authexample.task import ( | ||
Net, | ||
get_weights, | ||
load_data_from_disk, | ||
set_weights, | ||
test, | ||
train, | ||
) | ||
|
||
|
||
# Define Flower Client | ||
class FlowerClient(NumPyClient): | ||
def __init__(self, trainloader, valloader, local_epochs, learning_rate): | ||
self.net = Net() | ||
self.trainloader = trainloader | ||
self.valloader = valloader | ||
self.local_epochs = local_epochs | ||
self.lr = learning_rate | ||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
def fit(self, parameters, config): | ||
"""Train the model with data of this client.""" | ||
set_weights(self.net, parameters) | ||
results = train( | ||
self.net, | ||
self.trainloader, | ||
self.valloader, | ||
self.local_epochs, | ||
self.lr, | ||
self.device, | ||
) | ||
return get_weights(self.net), len(self.trainloader.dataset), results | ||
|
||
def evaluate(self, parameters, config): | ||
"""Evaluate the model on the data this client has.""" | ||
set_weights(self.net, parameters) | ||
loss, accuracy = test(self.net, self.valloader, self.device) | ||
return loss, len(self.valloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
def client_fn(context: Context): | ||
"""Construct a Client that will be run in a ClientApp.""" | ||
|
||
# Read the node_config to get the path to the dataset the SuperNode running | ||
# this ClientApp has access to | ||
dataset_path = context.node_config["dataset-path"] | ||
|
||
# Read run_config to fetch hyperparameters relevant to this run | ||
batch_size = context.run_config["batch-size"] | ||
trainloader, valloader = load_data_from_disk(dataset_path, batch_size) | ||
local_epochs = context.run_config["local-epochs"] | ||
learning_rate = context.run_config["learning-rate"] | ||
|
||
# Return Client instance | ||
return FlowerClient(trainloader, valloader, local_epochs, learning_rate).to_client() | ||
|
||
|
||
# Flower ClientApp | ||
app = ClientApp(client_fn) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""authexample: An authenticated Flower / PyTorch app.""" | ||
|
||
from typing import List, Tuple | ||
|
||
from flwr.common import Context, Metrics, ndarrays_to_parameters | ||
from flwr.server import ServerApp, ServerAppComponents, ServerConfig | ||
from flwr.server.strategy import FedAvg | ||
|
||
from authexample.task import Net, get_weights | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
# Multiply accuracy of each client by number of examples used | ||
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return {"accuracy": sum(accuracies) / sum(examples)} | ||
|
||
|
||
def server_fn(context: Context): | ||
"""Construct components that set the ServerApp behaviour.""" | ||
|
||
# Read from config | ||
num_rounds = context.run_config["num-server-rounds"] | ||
|
||
# Initialize model parameters | ||
ndarrays = get_weights(Net()) | ||
parameters = ndarrays_to_parameters(ndarrays) | ||
|
||
# Define the strategy | ||
strategy = FedAvg( | ||
fraction_fit=1.0, | ||
fraction_evaluate=context.run_config["fraction-evaluate"], | ||
min_available_clients=2, | ||
evaluate_metrics_aggregation_fn=weighted_average, | ||
initial_parameters=parameters, | ||
) | ||
config = ServerConfig(num_rounds=num_rounds) | ||
|
||
return ServerAppComponents(strategy=strategy, config=config) | ||
|
||
|
||
# Create ServerApp | ||
app = ServerApp(server_fn=server_fn) |
Oops, something went wrong.