-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for push-based algos (#126)
* Add support for pushing models * Fix no messages for a round and add generator for node sampling
- Loading branch information
Showing
12 changed files
with
376 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import random | ||
from collections import OrderedDict | ||
from typing import Any, Dict, List | ||
from torch import Tensor | ||
from utils.communication.comm_utils import CommunicationManager | ||
from algos.fl import FedAvgClient, FedAvgServer | ||
import time | ||
|
||
# import the possible attacks | ||
from algos.attack_add_noise import AddNoiseAttack | ||
from algos.attack_bad_weights import BadWeightsAttack | ||
from algos.attack_sign_flip import SignFlipAttack | ||
|
||
class FedAvgPushClient(FedAvgClient): | ||
def __init__( | ||
self, config: Dict[str, Any], comm_utils: CommunicationManager | ||
) -> None: | ||
super().__init__(config, comm_utils) | ||
|
||
def run_protocol(self): | ||
stats: Dict[str, Any] = {} | ||
print(f"Client {self.node_id} ready to start training") | ||
|
||
start_rounds = self.config.get("start_rounds", 0) | ||
total_rounds = self.config["rounds"] | ||
|
||
for round in range(start_rounds, total_rounds): | ||
# Fetch model from the server | ||
self.receive_pushed_and_aggregate() | ||
|
||
stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(round) | ||
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.local_test() | ||
|
||
# Send the model to the server | ||
self.push(self.server_node) | ||
|
||
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() | ||
|
||
self.log_metrics(stats=stats, iteration=round) | ||
|
||
self.local_round_done() | ||
|
||
|
||
class FedAvgPushServer(FedAvgServer): | ||
def __init__( | ||
self, config: Dict[str, Any], comm_utils: CommunicationManager | ||
) -> None: | ||
super().__init__(config, comm_utils) | ||
|
||
def single_round(self): | ||
""" | ||
Runs the whole training procedure | ||
""" | ||
self.push(self.users) | ||
self.receive_pushed_and_aggregate() | ||
|
||
def receive_pushed_and_aggregate(self): | ||
reprs = self.comm_utils.all_gather_pushed() | ||
avg_wts = self.aggregate(reprs) | ||
self.set_representation(avg_wts) | ||
|
||
def run_protocol(self): | ||
stats: Dict[str, Any] = {} | ||
print(f"Client {self.node_id} ready to start training") | ||
start_rounds = self.config.get("start_rounds", 0) | ||
total_rounds = self.config["rounds"] | ||
for round in range(start_rounds, total_rounds): | ||
self.single_round() | ||
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() | ||
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test() | ||
self.log_metrics(stats=stats, iteration=round) | ||
self.local_round_done() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Module for FedStaticClient and FedStaticServer in Federated Learning. | ||
""" | ||
from typing import Any, Dict, OrderedDict | ||
from utils.communication.comm_utils import CommunicationManager | ||
import torch | ||
import time | ||
|
||
from algos.fl_static import FedStaticNode, FedStaticServer | ||
from algos.topologies.collections import select_topology | ||
|
||
|
||
class SwiftNode(FedStaticNode): | ||
""" | ||
Federated Static Client Class. | ||
""" | ||
|
||
def __init__( | ||
self, config: Dict[str, Any], comm_utils: CommunicationManager | ||
) -> None: | ||
super().__init__(config, comm_utils) | ||
|
||
def run_protocol(self) -> None: | ||
""" | ||
Runs the federated learning protocol for the client. | ||
""" | ||
stats: Dict[str, Any] = {} | ||
print(f"Client {self.node_id} ready to start training") | ||
start_round = self.config.get("start_round", 0) | ||
if start_round != 0: | ||
raise NotImplementedError( | ||
"Start round different from 0 not implemented yet" | ||
) | ||
total_rounds = self.config["rounds"] | ||
epochs_per_round = self.config.get("epochs_per_round", 1) | ||
for it in range(start_round, total_rounds): | ||
# Train locally and send the representation to the server | ||
stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train( | ||
it, epochs_per_round | ||
) | ||
|
||
# Collect the representations from all other nodes from the server | ||
neighbors = self.topology.sample_neighbours(self.num_collaborators) | ||
# TODO: Log the neighbors | ||
stats["neighbors"] = neighbors | ||
|
||
self.push(neighbors) | ||
|
||
self.receive_pushed_and_aggregate() | ||
|
||
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() | ||
|
||
# evaluate the model on the test data | ||
# Inside FedStaticNode.run_protocol() | ||
stats["test_loss"], stats["test_acc"] = self.local_test() | ||
self.log_metrics(stats=stats, iteration=it) | ||
self.local_round_done() | ||
|
||
|
||
|
||
class SwiftServer(FedStaticServer): | ||
""" | ||
Swift Server Class. It does not do anything. | ||
It just exists to keep the code compatible across different algorithms. | ||
""" | ||
def __init__( | ||
self, config: Dict[str, Any], comm_utils: CommunicationManager | ||
) -> None: | ||
pass | ||
|
||
def run_protocol(self) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.