Add support for push-based algos (#126)
* Add support for pushing models

* Fix no messages for a round and add generator for node sampling
rishi-s8 authored Oct 28, 2024
1 parent eb8c4b7 commit 277d1a9
Showing 12 changed files with 376 additions and 36 deletions.
49 changes: 46 additions & 3 deletions src/algos/base_class.py
@@ -48,7 +48,6 @@ def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)


class BaseNode(ABC):
"""BaseNode is an abstract base class that provides foundational functionalities for nodes in a distributed system. It handles configuration, logging, CUDA setup, model parameter settings, and shared experiment parameters.
@@ -123,6 +122,7 @@ def set_constants(self) -> None:
"""Add docstring here"""
self.best_acc = 0.0
self.round = 0
self.EMPTY_MODEL_TAG = "EMPTY_MODEL"

def setup_logging(self, config: Dict[str, ConfigType]) -> None:
"""
@@ -319,13 +319,13 @@ def strip_empty_models(self, models_wts: List[OrderedDict[str, Any]],
if collab_weights is not None:
weight_list = []
for i, model_wts in enumerate(models_wts):
if len(model_wts) > 0 and collab_weights[i] > 0:
if self.EMPTY_MODEL_TAG not in model_wts and collab_weights[i] > 0:
repr_list.append(model_wts)
weight_list.append(collab_weights[i])
return repr_list, weight_list
else:
for model_wts in models_wts:
if len(model_wts) > 0:
if self.EMPTY_MODEL_TAG not in model_wts:
repr_list.append(model_wts)
return repr_list, None

@@ -358,6 +358,15 @@ def set_model_weights(

self.model.load_state_dict(model_wts, strict=len(keys_to_ignore) == 0)

def push(self, neighbors: List[int]) -> None:
"""
Pushes the model to the neighbors.
"""

data_to_send = self.get_model_weights()

self.comm_utils.send(neighbors, data_to_send)

class BaseClient(BaseNode):
"""
Abstract class for all algorithms
@@ -619,6 +628,26 @@ def receive_and_aggregate(self):
assert "model" in repr, "Model not found in the received message"
self.set_model_weights(repr["model"])

def receive_pushed_and_aggregate(self, remove_multi: bool = True) -> None:
model_updates = self.comm_utils.receive_pushed()

if len(model_updates) > 0:
if self.is_working:
# Remove multiple models of different rounds from each node
if remove_multi:
to_aggregate = {}
for model in model_updates:
sender = model.get("sender", 0)
if sender not in to_aggregate or to_aggregate[sender].get("round", 0) < model.get("round", 0):
to_aggregate[sender] = model
model_updates = list(to_aggregate.values())
# Clients normally receive a single pushed model (from the server); use the first remaining update
repr = model_updates[0]
assert "model" in repr, "Model not found in the received message"
self.set_model_weights(repr["model"])
else:
print("No one pushed model updates for this round.")

def run_protocol(self) -> None:
raise NotImplementedError

@@ -808,6 +837,20 @@ def aggregate(
self.set_model_weights(agg_wts)
return None

def receive_pushed_and_aggregate(self, remove_multi: bool = True) -> None:
model_updates = self.comm_utils.receive_pushed()
if self.is_working:
# Remove multiple models of different rounds from each node
if remove_multi:
to_aggregate = {}
for model in model_updates:
sender = model.get("sender", 0)
if sender not in to_aggregate or to_aggregate[sender].get("round", 0) < model.get("round", 0):
to_aggregate[sender] = model
model_updates = list(to_aggregate.values())
# Aggregate the representations
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)

def receive_and_aggregate(self, neighbors: List[int]) -> None:
if self.is_working:
# Receive the model updates from the neighbors
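The per-sender deduplication in receive_pushed_and_aggregate keeps only the newest update from each sender before anything is aggregated. A minimal, dependency-free sketch of that selection step (illustrative only, not part of this diff; the "sender"/"round" keys mirror the code above and the sample data is invented):

from typing import Any, Dict, List

def keep_latest_per_sender(model_updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # For each sender, retain only the update with the highest round number.
    latest: Dict[int, Dict[str, Any]] = {}
    for update in model_updates:
        sender = update.get("sender", 0)
        if sender not in latest or latest[sender].get("round", 0) < update.get("round", 0):
            latest[sender] = update
    return list(latest.values())

# Node 1 pushed twice (rounds 3 and 4); only its round-4 update survives.
updates = [
    {"sender": 1, "round": 3, "model": "w1_r3"},
    {"sender": 1, "round": 4, "model": "w1_r4"},
    {"sender": 2, "round": 4, "model": "w2_r4"},
]
print(keep_latest_per_sender(updates))  # two entries: sender 1 at round 4, sender 2 at round 4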
72 changes: 72 additions & 0 deletions src/algos/fl_push.py
@@ -0,0 +1,72 @@
import random
from collections import OrderedDict
from typing import Any, Dict, List
from torch import Tensor
from utils.communication.comm_utils import CommunicationManager
from algos.fl import FedAvgClient, FedAvgServer
import time

# import the possible attacks
from algos.attack_add_noise import AddNoiseAttack
from algos.attack_bad_weights import BadWeightsAttack
from algos.attack_sign_flip import SignFlipAttack

class FedAvgPushClient(FedAvgClient):
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)

def run_protocol(self):
stats: Dict[str, Any] = {}
print(f"Client {self.node_id} ready to start training")

start_rounds = self.config.get("start_rounds", 0)
total_rounds = self.config["rounds"]

for round in range(start_rounds, total_rounds):
# Receive the model pushed by the server
self.receive_pushed_and_aggregate()

stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(round)
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.local_test()

# Send the model to the server
self.push(self.server_node)

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()

self.log_metrics(stats=stats, iteration=round)

self.local_round_done()


class FedAvgPushServer(FedAvgServer):
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)

def single_round(self):
"""
Runs one round: push the global model to all users, then aggregate what they push back.
"""
self.push(self.users)
self.receive_pushed_and_aggregate()

def receive_pushed_and_aggregate(self):
reprs = self.comm_utils.all_gather_pushed()
avg_wts = self.aggregate(reprs)
self.set_representation(avg_wts)

def run_protocol(self):
stats: Dict[str, Any] = {}
print(f"Client {self.node_id} ready to start training")
start_rounds = self.config.get("start_rounds", 0)
total_rounds = self.config["rounds"]
for round in range(start_rounds, total_rounds):
self.single_round()
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test()
self.log_metrics(stats=stats, iteration=round)
self.local_round_done()
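For intuition, here is a toy, dependency-free sketch of the round shape implemented by FedAvgPushServer and FedAvgPushClient above: the server pushes the global model to every user, each user trains locally and pushes back, and the server averages whatever arrived. The scalar "weights", the inbox dicts, and the helper names are illustrative stand-ins, not the project's API:

import random

def local_train(weights: float) -> float:
    # Stand-in for FedAvgClient.local_train: nudge the scalar "weights" a little.
    return weights + random.uniform(-0.1, 0.1)

def fedavg_push_round(global_w: float, num_clients: int) -> float:
    client_inboxes = {i: [] for i in range(1, num_clients + 1)}
    server_inbox = []
    # Server "pushes" the current global model to every user.
    for inbox in client_inboxes.values():
        inbox.append(global_w)
    # Each client trains on the model it received and pushes the result back.
    for inbox in client_inboxes.values():
        server_inbox.append(local_train(inbox[-1]))
    # Server aggregates everything that was pushed (plain average here).
    return sum(server_inbox) / len(server_inbox)

w = 0.0
for r in range(2):
    w = fedavg_push_round(w, num_clients=4)
    print(f"round {r}: global weight ~ {w:.3f}")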
72 changes: 72 additions & 0 deletions src/algos/swift.py
@@ -0,0 +1,72 @@
"""
Module for FedStaticClient and FedStaticServer in Federated Learning.
"""
from typing import Any, Dict, OrderedDict
from utils.communication.comm_utils import CommunicationManager
import torch
import time

from algos.fl_static import FedStaticNode, FedStaticServer
from algos.topologies.collections import select_topology


class SwiftNode(FedStaticNode):
"""
Federated Static Client Class.
"""

def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)

def run_protocol(self) -> None:
"""
Runs the federated learning protocol for the client.
"""
stats: Dict[str, Any] = {}
print(f"Client {self.node_id} ready to start training")
start_round = self.config.get("start_round", 0)
if start_round != 0:
raise NotImplementedError(
"Start round different from 0 not implemented yet"
)
total_rounds = self.config["rounds"]
epochs_per_round = self.config.get("epochs_per_round", 1)
for it in range(start_round, total_rounds):
# Train locally and send the representation to the server
stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(
it, epochs_per_round
)

# Sample neighbours from the topology and push the local model to them
neighbors = self.topology.sample_neighbours(self.num_collaborators)
# TODO: Log the neighbors
stats["neighbors"] = neighbors

self.push(neighbors)

self.receive_pushed_and_aggregate()

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()

# evaluate the model on the test data
stats["test_loss"], stats["test_acc"] = self.local_test()
self.log_metrics(stats=stats, iteration=it)
self.local_round_done()



class SwiftServer(FedStaticServer):
"""
Swift Server Class. It does not do anything.
It just exists to keep the code compatible across different algorithms.
"""
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
pass

def run_protocol(self) -> None:
pass
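Unlike the client-server flow in fl_push.py, SwiftNode runs a decentralized push: each node samples neighbours from its topology, pushes its model to them, and aggregates whatever was pushed to it, possibly nothing at all (the "No one pushed model updates" branch). A toy, synchronous ring simulation of that pattern with scalar "weights" and plain averaging (the real aggregate() may weight models differently):

import numpy as np

def ring_neighbours(i: int, n: int) -> list:
    return [(i - 1) % n, (i + 1) % n]

def swift_round(weights: list, rngs: list, k: int = 1) -> list:
    n = len(weights)
    inboxes = [[] for _ in range(n)]
    # Every node pushes its current weights to k sampled neighbours.
    for i in range(n):
        for t in rngs[i].choice(ring_neighbours(i, n), size=k, replace=False):
            inboxes[t].append(weights[i])
    # Every node averages its own weights with whatever was pushed to it;
    # an empty inbox means it simply keeps its local model.
    return [
        (weights[i] + sum(inboxes[i])) / (1 + len(inboxes[i])) if inboxes[i] else weights[i]
        for i in range(n)
    ]

n = 4
rngs = [np.random.default_rng(1 * 10000 + rank) for rank in range(n)]  # per-rank generators, as in topologies/base.py
weights = [float(i) for i in range(n)]
for _ in range(3):
    weights = swift_round(weights, rngs)
print([round(w, 3) for w in weights])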
3 changes: 2 additions & 1 deletion src/algos/topologies/base.py
@@ -16,6 +16,7 @@ def __init__(self, config: ConfigType, rank: int) -> None:
self.rank = rank
self.num_users: int = self.config["num_users"] # type: ignore
self.graph: nx.Graph | None = None
self.neighbor_sample_generator = np.random.default_rng(seed=int(self.config["seed"])*10000 + self.rank ) # type: ignore

@abstractmethod
def generate_graph(self) -> None:
@@ -63,7 +64,7 @@ def sample_neighbours(self, k: int) -> List[int]:
neighbours = self.get_all_neighbours()
if len(neighbours) <= k:
return neighbours
return np.random.choice(neighbours, size=k, replace=False).tolist()
return self.neighbor_sample_generator.choice(neighbours, size=k, replace=False).tolist()

def get_neighbourhood_size(self) -> int:
"""
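The neighbor_sample_generator added above gives every rank its own seeded RNG, so neighbour sampling is reproducible per node and no longer depends on the global numpy random state. A small illustration of that property (the seed formula seed * 10000 + rank is taken from the diff; the neighbour list is invented):

import numpy as np

seed = 1
neighbours = [2, 5, 7, 9]  # hypothetical neighbour ranks

# Each rank derives its own generator, as in the topology __init__ above.
gen_rank0 = np.random.default_rng(seed * 10000 + 0)
gen_rank1 = np.random.default_rng(seed * 10000 + 1)

print(gen_rank0.choice(neighbours, size=2, replace=False))  # rank 0's draw
print(gen_rank1.choice(neighbours, size=2, replace=False))  # a different stream for rank 1

# Re-creating rank 0's generator reproduces its draw exactly.
print(np.random.default_rng(seed * 10000 + 0).choice(neighbours, size=2, replace=False))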
23 changes: 23 additions & 0 deletions src/configs/algo_config.py
@@ -200,6 +200,29 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
"batch_size": 256,
}

swift: ConfigType = {
# Collaboration setup
"algo": "swift",
"topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
"rounds": 20,

# Model parameters
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
}

fedavgpush: ConfigType = {
# Collaboration setup
"algo": "fedavgpush",
"rounds": 2,

# Model parameters
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
}

metaL2C_cifar10: ConfigType = {
"algo": "metal2c",
"sharing": "weights", # "updates"
18 changes: 9 additions & 9 deletions src/configs/sys_config.py
@@ -10,7 +10,9 @@
malicious_algo_config_list,
default_config_list,
fedstatic,
traditional_fl
traditional_fl,
swift,
fedavgpush
)

sliding_window_8c_4cpc_support = {
@@ -277,21 +279,17 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"seed": 1,
"num_collaborators": NUM_COLLABORATORS,
"load_existing": False,
"device_ids": get_device_ids(num_users=swarm_users, gpus_available=[1, 2]),
# "algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(
num_users=swarm_users,
algo_configs=default_config_list,
), # type: ignore
"dump_dir": DUMP_DIR,
"device_ids": get_device_ids(num_users=swarm_users, gpus_available=[3, 4]),
"algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore
# Dataset params
"dset": get_domainnet_support(
swarm_users
), # get_camelyon17_support(fedcentral_client), #get_domainnet_support(fedcentral_client),
"dpath": domainnet_dpath, # wilds_dpath,#domainnet_dpath,
"train_label_distribution": "iid", # Either "iid", "shard" "support",
"test_label_distribution": "iid", # Either "iid" "support",
"samples_per_user": 500,
"samples_per_user": 32,
"test_samples_per_class": 100,
"community_type": "dataset",
"exp_keys": [],
@@ -333,6 +331,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
for i in range(1, num_users + 1):
dropout_dicts[f"node_{i}"] = dropout_dict

# for swift or fedavgpush, just modify the algo_configs list (see the config sketch after this file's diff)
# for swift, synchronous should preferably be False
gpu_ids = [2, 3, 5, 6]
grpc_system_config: ConfigType = {
"exp_id": "static",
@@ -345,7 +345,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
"samples_per_user": 50000 // num_users, # distributed equally
"train_label_distribution": "iid",
"test_label_distribution": "iid",
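As the comment above says, switching an experiment to the new algorithms only requires changing the algo_configs list in the system config. A hedged sketch of the two variants (grpc_system_config, get_algo_configs, num_users, swift, and fedavgpush all appear in this diff; whether a synchronous flag exists and how it is read is an assumption based on the swift comment):

# Illustrative edits, not part of this commit:
grpc_system_config["algos"] = get_algo_configs(num_users=num_users, algo_configs=[fedavgpush])  # push-based FedAvg

# or, for decentralized push gossip:
grpc_system_config["algos"] = get_algo_configs(num_users=num_users, algo_configs=[swift])
# Per the comment above, swift should preferably run asynchronously; the exact
# config key for that is assumed here, so check how the runtime reads it.
grpc_system_config["synchronous"] = False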
4 changes: 4 additions & 0 deletions src/scheduler.py
@@ -20,6 +20,7 @@
from algos.fl_weight import FedWeightClient, FedWeightServer
from algos.fl_static import FedStaticNode, FedStaticServer
from algos.swarm import SWARMClient, SWARMServer
from algos.swift import SwiftNode, SwiftServer
from algos.DisPFL import DisPFLClient, DisPFLServer
from algos.def_kt import DefKTClient, DefKTServer
from algos.fedfomo import FedFomoClient, FedFomoServer
@@ -28,6 +29,7 @@
from algos.fl_central import CentralizedCLient, CentralizedServer
from algos.fl_data_repr import FedDataRepClient, FedDataRepServer
from algos.fl_val import FedValClient, FedValServer
from algos.fl_push import FedAvgPushClient, FedAvgPushServer

from utils.communication.comm_utils import CommunicationManager
from utils.config_utils import load_config, process_config
@@ -50,6 +52,8 @@
"centralized": [CentralizedServer, CentralizedCLient],
"feddatarepr": [FedDataRepServer, FedDataRepClient],
"fedval": [FedValServer, FedValClient],
"swift": [SwiftServer, SwiftNode],
"fedavgpush": [FedAvgPushServer, FedAvgPushClient],
}


6 changes: 6 additions & 0 deletions src/utils/communication/comm_utils.py
@@ -81,3 +81,9 @@ def set_is_working(self, is_working: bool):

def get_comm_cost(self):
return self.comm.get_comm_cost()

def receive_pushed(self):
return self.comm.receive_pushed()

def all_gather_pushed(self):
return self.comm.all_gather_pushed()
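These two pass-throughs expose the receive side of the push path. Their return shape is not documented in this diff; inferred from the call sites in base_class.py, receive_pushed() returns a list of per-sender dicts and all_gather_pushed() returns the gathered model representations. An assumed example of one such entry (field names taken from how base_class.py reads them, not from the comm backend itself):

pushed_update = {
    "sender": 3,   # rank of the node that pushed
    "round": 5,    # round in which it pushed
    "model": {},   # the model weights (an OrderedDict / state_dict in the real system)
}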
1 change: 1 addition & 0 deletions src/utils/communication/grpc/comm.proto
@@ -4,6 +4,7 @@ syntax = "proto3";

service CommunicationServer {
rpc send_data (Data) returns (Empty) {}
rpc send_model (Model) returns (Empty) {}
rpc get_rank (Empty) returns (Rank) {}
rpc get_model (Empty) returns (Model) {}
rpc get_current_round (Empty) returns (Round) {}