diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..9e6483a
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+ "python.analysis.typeCheckingMode": "strict"
+}
\ No newline at end of file
diff --git a/src/README.md b/src/README.md
index 0e66197..49136fe 100644
--- a/src/README.md
+++ b/src/README.md
@@ -31,3 +31,8 @@ Client Metrics
Server Metrics
+
+### Debugging instructions
+The gRPC simulation starts many processes, and right now if even one of them fails you have to kill all of them and start over.
+So, here is a command to get the PIDs of all the simulation processes and kill them all at once:
+`for pid in $(ps aux | grep 'python main.py -r' | awk '{print $2}'); do kill -9 $pid; done`
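+Alternatively, if `pkill` is available on your machine, the same cleanup in one step: `pkill -9 -f 'python main.py -r'`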
\ No newline at end of file
diff --git a/src/algos/base_class.py b/src/algos/base_class.py
index ba83a71..049d8c9 100644
--- a/src/algos/base_class.py
+++ b/src/algos/base_class.py
@@ -4,14 +4,14 @@
from torch.utils.data import DataLoader, Subset
from collections import OrderedDict
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from torch import Tensor
import copy
import random
import numpy as np
+from utils.communication.comm_utils import CommunicationManager
from utils.plot_utils import PlotUtils
-from utils.comm_utils import CommUtils
from utils.data_utils import (
random_samples,
filter_by_class,
@@ -28,15 +28,15 @@
get_dset_balanced_communities,
get_dset_communities,
)
-import torchvision.transforms as T
+import torchvision.transforms as T # type: ignore
import os
from yolo import YOLOLoss
class BaseNode(ABC):
- def __init__(self, config) -> None:
- self.comm_utils = CommUtils()
- self.node_id = self.comm_utils.rank
+ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
+ self.comm_utils = comm_utils
+ self.node_id = self.comm_utils.get_rank()
if self.node_id == 0:
self.log_dir = config['log_path']
@@ -54,13 +54,13 @@ def __init__(self, config) -> None:
if isinstance(config["dset"], dict):
if self.node_id != 0:
config["dset"].pop("0")
- self.dset = config["dset"][str(self.node_id)]
+ self.dset = str(config["dset"][str(self.node_id)])
config["dpath"] = config["dpath"][self.dset]
else:
self.dset = config["dset"]
self.setup_cuda(config)
- self.model_utils = ModelUtils()
+ self.model_utils = ModelUtils(self.device)
self.dset_obj = get_dataset(self.dset, dpath=config["dpath"])
self.set_constants()
@@ -68,7 +68,7 @@ def __init__(self, config) -> None:
def set_constants(self):
self.best_acc = 0.0
- def setup_cuda(self, config):
+ def setup_cuda(self, config: Dict[str, Any]):
# Need a mapping from rank to device id
device_ids_map = config["device_ids"]
node_name = "node_{}".format(self.node_id)
@@ -82,7 +82,7 @@ def setup_cuda(self, config):
self.device = torch.device("cpu")
print("Using CPU")
- def set_model_parameters(self, config):
+ def set_model_parameters(self, config: Dict[str, Any]):
# Model related parameters
optim_name = config.get("optimizer", "adam")
if optim_name == "adam":
@@ -149,7 +149,7 @@ def set_shared_exp_parameters(self, config):
self.log_utils.log_console("Communities: {}".format(self.communities))
@abstractmethod
- def run_protocol(self):
+ def run_protocol(self) -> None:
raise NotImplementedError
@@ -158,8 +158,8 @@ class BaseClient(BaseNode):
Abstract class for all algorithms
"""
- def __init__(self, config) -> None:
- super().__init__(config)
+    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
+ super().__init__(config, comm_utils)
self.server_node = 0
self.set_parameters(config)
@@ -215,8 +215,8 @@ def set_data_parameters(self, config):
train_dset = self.dset_obj.train_dset
test_dset = self.dset_obj.test_dset
- print("num train", len(train_dset))
- print("num test", len(test_dset))
+ # print("num train", len(train_dset))
+ # print("num test", len(test_dset))
if config.get("test_samples_per_class", None) is not None:
test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"])
@@ -369,19 +369,19 @@ def is_same_dest(dset):
# TODO: fix print_data_summary
# self.print_data_summary(train_dset, test_dset, val_dset=val_dset)
- def local_train(self, dataset, **kwargs):
+ def local_train(self, round: int, **kwargs: Any) -> None:
"""
Train the model locally
"""
raise NotImplementedError
- def local_test(self, dataset, **kwargs):
+ def local_test(self, **kwargs: Any) -> float | Tuple[float, float] | None:
"""
Test the model locally
"""
raise NotImplementedError
- def get_representation(self, **kwargs):
+ def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor:
"""
Share the model representation
"""
@@ -416,21 +416,17 @@ def print_data_summary(self, train_test, test_dset, val_dset=None):
print("test count: ", i)
i += 1
- print("Node: {} data distribution summary".format(self.node_id))
- print(type(train_sample_per_class.items()))
- print(
- "Train samples per class: {}".format(sorted(train_sample_per_class.items()))
- )
- print(
- "Train samples per class: {}".format(len(train_sample_per_class.items()))
- )
- if val_dset is not None:
- print(
- "Val samples per class: {}".format(len(val_sample_per_class.items()))
- )
- print(
- "Test samples per class: {}".format(len(test_sample_per_class.items()))
- )
+ # print("Node: {} data distribution summary".format(self.node_id))
+ # print(
+ # "Train samples per class: {}".format(sorted(train_sample_per_class.items()))
+ # )
+ # if val_dset is not None:
+ # print(
+ # "Val samples per class: {}".format(sorted(val_sample_per_class.items()))
+ # )
+ # print(
+ # "Test samples per class: {}".format(sorted(test_sample_per_class.items()))
+ # )
class BaseServer(BaseNode):
@@ -438,8 +434,8 @@ class BaseServer(BaseNode):
Abstract class for orchestrator
"""
- def __init__(self, config) -> None:
- super().__init__(config)
+    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
+ super().__init__(config, comm_utils)
self.num_users = config["num_users"]
self.users = list(range(1, self.num_users + 1))
self.set_data_parameters(config)
@@ -449,13 +445,13 @@ def set_data_parameters(self, config):
batch_size = config["batch_size"]
self._test_loader = DataLoader(test_dset, batch_size=batch_size)
- def aggregate(self, representation_list, **kwargs):
+ def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Aggregate the knowledge from the users
"""
raise NotImplementedError
- def test(self, dataset, **kwargs):
+ def test(self, **kwargs: Any) -> List[float]:
"""
Test the model on the server
"""
@@ -668,10 +664,5 @@ def __init__(self, config, comm_protocol=CommProtocol) -> None:
super().__init__(config)
self.tag = comm_protocol
- def send_representations(self, representations, tag=None):
- for user_node in self.users:
- self.comm_utils.send_signal(
- dest=user_node,
- data=representations,
- tag=self.tag.REPRS_SHARE if tag is None else tag,
- )
+ def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]):
+ self.comm_utils.broadcast(representations)
diff --git a/src/algos/fl.py b/src/algos/fl.py
index 781c417..94eb368 100644
--- a/src/algos/fl.py
+++ b/src/algos/fl.py
@@ -1,48 +1,32 @@
from collections import OrderedDict
+import sys
from typing import Any, Dict, List
from torch import Tensor
-import torch.nn as nn
+from utils.communication.comm_utils import CommunicationManager
from utils.log_utils import LogUtils
from algos.base_class import BaseClient, BaseServer
import os
import time
-class CommProtocol(object):
- """
- Communication protocol tags for the server and clients
- """
-
- DONE = 0 # Used to signal that the client is done with the current round
- START = 1 # Used to signal by the server to start the current round
- UPDATES = 2 # Used to send the updates from the server to the clients
-
-
class FedAvgClient(BaseClient):
- def __init__(self, config) -> None:
- super().__init__(config)
+ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
+ super().__init__(config, comm_utils)
self.config = config
- self.tag = CommProtocol
- self.folder_deletion_signal = config["folder_deletion_signal_path"]
- while not os.path.exists(self.folder_deletion_signal):
- print("Existing experiment already present, waiting user input, enter 'r' or 'e'...")
- time.sleep(5)
-
- # Once the signal file exists, read its contents
- with open(self.folder_deletion_signal, "r") as signal_file:
- mode = signal_file.read().strip()
-
- if mode == 'r' or mode == 'new':
- try:
- config['log_path'] = f"{config['log_path']}/client_{self.node_id}"
- os.makedirs(config['log_path'], exist_ok=True)
- except FileExistsError:
- pass
+ try:
+ config['log_path'] = f"{config['log_path']}/node_{self.node_id}"
+ os.makedirs(config['log_path'])
+ except FileExistsError:
+ color_code = "\033[91m" # Red color
+ reset_code = "\033[0m" # Reset to default color
+            print(f"{color_code}Log directory for node {self.node_id} already exists at {config['log_path']}")
+ print(f"Exiting to prevent accidental overwrite{reset_code}")
+ sys.exit(1)
config['load_existing'] = False
self.client_log_utils = LogUtils(config)
- def local_train(self, round):
+ def local_train(self, round: int, **kwargs: Any):
"""
Train the model locally
"""
@@ -61,17 +45,17 @@ def local_train(self, round):
self.client_log_utils.log_tb(f"train_loss/client{self.node_id}", avg_loss, round)
self.client_log_utils.log_tb(f"train_accuracy/client{self.node_id}", avg_accuracy, round)
- def local_test(self, **kwargs):
+ def local_test(self, **kwargs: Any):
"""
Test the model locally, not to be used in the traditional FedAvg
"""
pass
- def get_representation(self) -> OrderedDict[str, Tensor]:
+ def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Share the model weights
"""
- return self.model.state_dict()
+ return self.model.state_dict() # type: ignore
def set_representation(self, representation: OrderedDict[str, Tensor]):
"""
@@ -83,71 +67,79 @@ def run_protocol(self):
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
for round in range(start_epochs, total_epochs):
- self.client_log_utils.log_summary("Client {} waiting for semaphore from {}".format(self.node_id, self.server_node))
- self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.START)
- self.client_log_utils.log_summary("Client {} received semaphore from {}".format(self.node_id, self.server_node))
self.local_train(round)
self.local_test()
repr = self.get_representation()
self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node))
- self.comm_utils.send_signal(
- dest=self.server_node, data=repr, tag=self.tag.DONE
- )
+ self.comm_utils.send(self.server_node, repr)
self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node))
- repr = self.comm_utils.wait_for_signal(
- src=self.server_node, tag=self.tag.UPDATES
- )
+ repr = self.comm_utils.receive(self.server_node)
self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node))
self.set_representation(repr)
- self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id))
+ # self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id))
class FedAvgServer(BaseServer):
- def __init__(self, config) -> None:
- super().__init__(config)
+ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
+ super().__init__(config, comm_utils)
# self.set_parameters()
self.config = config
self.set_model_parameters(config)
- self.tag = CommProtocol
self.model_save_path = "{}/saved_models/node_{}.pt".format(
self.config["results_path"], self.node_id
)
self.folder_deletion_signal = config["folder_deletion_signal_path"]
+ # def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
+ # # All models are sampled currently at every round
+ # # Each model is assumed to have equal amount of data and hence
+ # # coeff is same for everyone
+ # num_users = len(model_wts)
+ # coeff = 1 / num_users # this assumes each node has equal amount of data
+ # avgd_wts: OrderedDict[str, Tensor] = OrderedDict()
+ # first_model = model_wts[0]
+
+ # for node_num in range(num_users):
+ # local_wts = model_wts[node_num]
+ # for key in first_model.keys():
+ # if node_num == 0:
+ # avgd_wts[key] = coeff * local_wts[key].to('cpu')
+ # else:
+ # avgd_wts[key] += coeff * local_wts[key].to('cpu')
+ # # put the model back to the device
+ # for key in avgd_wts.keys():
+ # avgd_wts[key] = avgd_wts[key].to(self.device)
+ # return avgd_wts
+
def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
- # All models are sampled currently at every round
- # Each model is assumed to have equal amount of data and hence
- # coeff is same for everyone
num_users = len(model_wts)
coeff = 1 / num_users
- avgd_wts = OrderedDict()
- first_model = model_wts[0]
-
- for client_num in range(num_users):
- local_wts = model_wts[client_num]
- for key in first_model.keys():
- if client_num == 0:
- avgd_wts[key] = coeff * local_wts[key].to(self.device)
- else:
- avgd_wts[key] += coeff * local_wts[key].to(self.device)
+ avgd_wts: OrderedDict[str, Tensor] = OrderedDict()
+
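+        # average key-by-key; assumes every client's state_dict has identical keys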
+ for key in model_wts[0].keys():
+ avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore
+
+ # Move to GPU only after averaging
+ for key in avgd_wts.keys():
+ avgd_wts[key] = avgd_wts[key].to(self.device)
return avgd_wts
- def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]):
+ def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Aggregate the model weights
"""
avg_wts = self.fed_avg(representation_list)
return avg_wts
- def set_representation(self, representation):
+ def set_representation(self, representation: OrderedDict[str, Tensor]):
"""
Set the model
"""
- for client_node in self.users:
- self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
+ self.comm_utils.broadcast(representation)
+        print("broadcasted model to all users")
self.model.load_state_dict(representation)
- def test(self) -> float:
+ def test(self, **kwargs: Any) -> List[float]:
"""
Test the model on the server
"""
@@ -162,22 +154,16 @@ def test(self) -> float:
if test_acc > self.best_acc:
self.best_acc = test_acc
self.model_utils.save_model(self.model, self.model_save_path)
- return test_loss, test_acc, time_taken
+ return [test_loss, test_acc, time_taken]
def single_round(self):
"""
Runs the whole training procedure
"""
- for client_node in self.users:
- self.log_utils.log_console(
- "Server sending semaphore from {} to {}".format(
- self.node_id, client_node
- )
- )
- self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.START)
- self.log_utils.log_console("Server waiting for all clients to finish")
- reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.DONE)
- self.log_utils.log_console("Server received all clients done signal")
+ # self.log_utils.log_console("Server waiting for all clients to finish")
+ reprs = self.comm_utils.all_gather()
+ # self.log_utils.log_console("Server received all clients done signal")
avg_wts = self.aggregate(reprs)
self.set_representation(avg_wts)
#Remove the signal file after confirming that all client paths have been created
@@ -185,7 +171,7 @@ def single_round(self):
os.remove(self.folder_deletion_signal)
def run_protocol(self):
- self.log_utils.log_console("Starting iid clients federated averaging")
+ self.log_utils.log_console("Starting clients federated averaging")
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
for round in range(start_epochs, total_epochs):
@@ -197,6 +183,6 @@ def run_protocol(self):
self.log_utils.log_tb(f"test_acc/clients", acc, round)
self.log_utils.log_tb(f"test_loss/clients", loss, round)
self.log_utils.log_console("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
- self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
+ # self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
self.log_utils.log_console("Round {} complete".format(round))
self.log_utils.log_summary("Round {} complete".format(round,))
diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py
index 9f1044e..95ff516 100644
--- a/src/configs/algo_config.py
+++ b/src/configs/algo_config.py
@@ -22,7 +22,7 @@
"exp_type": "iid_clients_federated",
# Learning setup
"epochs": 1000,
- "model": "resnet34",
+ "model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
"exp_keys": [],
diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py
index 884bd13..bfe99fd 100644
--- a/src/configs/sys_config.py
+++ b/src/configs/sys_config.py
@@ -1,9 +1,15 @@
# System Configuration
# TODO: Set up multiple non-iid configurations here. The goal of a separate system config
# is to simulate different real-world scenarios without changing the algorithm configuration.
-system_config = {
- "num_users": 3,
- "experiment_path": "./experiments/",
+from typing import Dict, List
+
+
+mpi_system_config = {
+ "comm": {
+ "type": "MPI"
+ },
+ "num_users": 4,
+ # "experiment_path": "./experiments/",
"dset": "cifar10",
"dump_dir": "./expt_dump/",
"dpath": "./datasets/imgs/cifar10/",
@@ -37,4 +43,35 @@
"folder_deletion_signal_path":"./expt_dump/folder_deletion.signal"
}
-current_config = object_detect_system_config
\ No newline at end of file
+def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]:
+ """
+ Get the GPU device IDs for the users.
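+    e.g. get_device_ids(3, [0, 1]) -> {"node_0": [0], "node_1": [1], "node_2": [0]}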
+ """
+ # TODO: Make it multi-host
+ device_ids: Dict[str, List[int]] = {}
+ for i in range(num_users):
+ index = i % len(gpus_available)
+ gpu_id = gpus_available[index]
+ device_ids[f"node_{i}"] = [gpu_id]
+ return device_ids
+
+num_users = 10
+gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7]
+grpc_system_config = {
+ "num_users": num_users,
+ "comm": {
+ "type": "GRPC",
+ "peer_ids": ["localhost:50050"] # The super-node
+ },
+ "dset": "cifar10",
+ "dump_dir": "./expt_dump/",
+ "dpath": "./datasets/imgs/cifar10/",
+ "seed": 2,
+ "device_ids": get_device_ids(num_users + 1, gpu_ids), # +1 for the super-node
+ "samples_per_user": 500,
+ "train_label_distribution": "iid",
+ "test_label_distribution": "iid",
+ "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal"
+}
+
+current_config = grpc_system_config
diff --git a/src/main.py b/src/main.py
index f4952fc..4880f22 100644
--- a/src/main.py
+++ b/src/main.py
@@ -11,6 +11,7 @@
B_DEFAULT = "./configs/algo_config.py"
S_DEFAULT = "./configs/sys_config.py"
+RANK_DEFAULT = None
parser = argparse.ArgumentParser(description="Run collaborative learning experiments")
parser.add_argument(
@@ -27,11 +28,18 @@
type=str,
help=f"filepath for system config, default: {S_DEFAULT}",
)
+parser.add_argument(
+ "-r",
+ nargs="?",
+ default=RANK_DEFAULT,
+ type=int,
+ help=f"rank of the node, default: {RANK_DEFAULT}",
+)
args = parser.parse_args()
scheduler = Scheduler()
-scheduler.assign_config_by_path(args.s, args.b)
+scheduler.assign_config_by_path(args.s, args.b, args.r)
print("Config loaded")
scheduler.install_config()
diff --git a/src/main_grpc.py b/src/main_grpc.py
new file mode 100644
index 0000000..546e6c0
--- /dev/null
+++ b/src/main_grpc.py
@@ -0,0 +1,37 @@
+"""
+This module runs collaborative learning experiments using the Scheduler class.
+"""
+
+import argparse
+import logging
+import subprocess
+
+from utils.config_utils import load_config
+
+logging.getLogger("PIL").setLevel(logging.INFO)
+
+S_DEFAULT = "./configs/sys_config.py"
+RANK_DEFAULT = 0
+
+parser = argparse.ArgumentParser(description="Run collaborative learning experiments")
+parser.add_argument(
+ "-s",
+ nargs="?",
+ default=S_DEFAULT,
+ type=str,
+ help=f"filepath for system config, default: {S_DEFAULT}",
+)
+
+args = parser.parse_args()
+
+sys_config = load_config(args.s)
+print("Sys config loaded")
+
+# 1. find the number of users in the system configuration
+# 2. start separate processes by running python main.py for each user
+
+num_users = sys_config["num_users"] + 1 # +1 for the super-node
+for i in range(num_users):
+ print(f"Starting process for user {i}")
+ # start a Popen process
+ subprocess.Popen(["python", "main.py", "-r", str(i)])
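+    # note: we do not wait on these Popen handles; every node keeps running in
+    # its own process (see src/README.md for how to kill them all at once)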
\ No newline at end of file
diff --git a/src/resnet_in.py b/src/resnet_in.py
index b160e3e..dca143b 100644
--- a/src/resnet_in.py
+++ b/src/resnet_in.py
@@ -5,6 +5,7 @@
This module implements ResNet models for ImageNet classification.
"""
+from typing import Any
from torch import nn
from torch.hub import load_state_dict_from_url
@@ -297,7 +298,7 @@ def _resnet(
return model
-def resnet18(pretrained=False, progress=True, **kwargs):
+def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
Args:
@@ -309,7 +310,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
-def resnet34(pretrained=False, progress=True, **kwargs):
+def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
Args:
diff --git a/src/scheduler.py b/src/scheduler.py
index c19e160..568a936 100644
--- a/src/scheduler.py
+++ b/src/scheduler.py
@@ -5,12 +5,10 @@
import os
import random
+from typing import Any, Dict
-from mpi4py import MPI
import torch
-import numpy as np
-
import random
import numpy
@@ -31,11 +29,9 @@
from algos.fl_data_repr import FedDataRepClient, FedDataRepServer
from algos.fl_val import FedValClient, FedValServer
-from utils.log_utils import check_and_create_path
+from utils.communication.comm_utils import CommunicationManager
from utils.config_utils import load_config, process_config
-
from utils.log_utils import copy_source_code, check_and_create_path
-from utils.config_utils import load_config, process_config, get_device_ids
import os
@@ -64,10 +60,10 @@
"fedval": [FedValServer, FedValClient],
}
-def get_node(config: dict, rank) -> BaseNode:
+def get_node(config: Dict[str, Any], rank: int, comm_utils: CommunicationManager) -> BaseNode:
algo_name = config["algo"]
- return algo_map[algo_name][rank > 0](config)
-
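+    # algo_map[name] = [ServerClass, ClientClass]; rank 0 is the server,
+    # so indexing with (rank > 0) picks the client class for everyone else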
+ return algo_map[algo_name][rank > 0](config, comm_utils)
+
class Scheduler():
""" Manages the overall orchestration of experiments
"""
@@ -76,10 +72,12 @@ def __init__(self) -> None:
pass
def install_config(self) -> None:
- self.config = process_config(self.config)
+ self.config: Dict[str, Any] = process_config(self.config)
- def assign_config_by_path(self, sys_config_path, algo_config_path):
+    def assign_config_by_path(self, sys_config_path: str, algo_config_path: str, rank: int|None = None) -> None:
self.sys_config = load_config(sys_config_path)
+ if rank is not None:
+ self.sys_config["comm"]["rank"] = rank
self.algo_config = load_config(algo_config_path)
self.merge_configs()
@@ -88,19 +86,17 @@ def merge_configs(self):
self.config.update(self.sys_config)
self.config.update(self.algo_config)
- def initialize(self, copy_souce_code=True) -> None:
+ def initialize(self, copy_souce_code: bool=True) -> None:
assert self.config is not None, "Config should be set when initializing"
-
- comm = MPI.COMM_WORLD
- rank = comm.Get_rank()
+ self.communication = CommunicationManager(self.config)
# Base clients modify the seed later on
seed = self.config["seed"]
- torch.manual_seed(seed)
+ torch.manual_seed(seed) # type: ignore
random.seed(seed)
numpy.random.seed(seed)
- if rank == 0:
+ if self.communication.get_rank() == 0:
if copy_souce_code:
copy_source_code(self.config)
else:
@@ -109,7 +105,7 @@ def initialize(self, copy_souce_code=True) -> None:
os.mkdir(self.config["saved_models"])
os.mkdir(self.config["log_path"])
- self.node = get_node(self.config, rank=rank)
+ self.node = get_node(self.config, rank=self.communication.get_rank(), comm_utils=self.communication)
def run_job(self) -> None:
self.node.run_protocol()
diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py
new file mode 100644
index 0000000..cb4a979
--- /dev/null
+++ b/src/utils/communication/comm_utils.py
@@ -0,0 +1,55 @@
+from enum import Enum
+from typing import Any, Dict
+
+from utils.communication.grpc.main import GRPCCommunication
+from utils.communication.mpi import MPICommUtils
+
+
+class CommunicationType(Enum):
+ MPI = 1
+ GRPC = 2
+ HTTP = 3
+
+
+class CommunicationFactory:
+ @staticmethod
+ def create_communication(config: Dict[str, Any], comm_type: CommunicationType) -> Any:
+ if comm_type == CommunicationType.MPI:
+ return MPICommUtils(config)
+ elif comm_type == CommunicationType.GRPC:
+ return GRPCCommunication(config)
+ elif comm_type == CommunicationType.HTTP:
+ raise NotImplementedError("HTTP communication not yet implemented")
+ else:
+ raise ValueError("Invalid communication type")
+
+
+class CommunicationManager:
+ def __init__(self, config: Dict[str, Any]):
+ self.comm_type = CommunicationType[config["comm"]["type"]]
+ self.comm = CommunicationFactory.create_communication(config, self.comm_type)
+ self.comm.initialize()
+
+    def get_rank(self) -> int:
+        if self.comm_type in (CommunicationType.MPI, CommunicationType.GRPC):
+            return self.comm.rank
+        else:
+            raise NotImplementedError("Rank not implemented for communication type", self.comm_type)
+
+ def send(self, dest:str|int, data:Any):
+ self.comm.send(dest, data)
+
+ def receive(self, node_ids: str|int) -> Any:
+ return self.comm.receive(node_ids)
+
+ def broadcast(self, data: Any):
+ self.comm.broadcast(data)
+
+ def all_gather(self):
+ return self.comm.all_gather()
+
+ def finalize(self):
+ self.comm.finalize()
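+
+# Minimal usage sketch (config keys mirror the "comm" section of sys_config.py):
+#   config = {"num_users": 2, "comm": {"type": "GRPC", "rank": 0, "peer_ids": ["localhost:50050"]}}
+#   comm = CommunicationManager(config)
+#   print(comm.get_rank())
+#   comm.finalize()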
diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto
new file mode 100644
index 0000000..426b580
--- /dev/null
+++ b/src/utils/communication/grpc/comm.proto
@@ -0,0 +1,44 @@
+// To generate the gRPC code, run the following command:
+// python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. comm.proto --pyi_out=.
+syntax = "proto3";
+
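+// send_data carries serialized model weights between peers; the remaining
+// RPCs implement the registration/quorum bootstrap run by the super-node.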
+service CommunicationServer {
+ rpc send_data (Data) returns (Empty) {}
+ rpc get_rank (Empty) returns (Rank) {}
+ rpc update_port (PeerId) returns (Empty) {}
+ rpc send_peer_ids (PeerIds) returns (Empty) {}
+ rpc send_quorum (Quorum) returns (Empty) {}
+}
+
+message Empty {}
+
+message Model {
+ bytes buffer = 1;
+}
+
+message Data {
+ string id = 1;
+ Model model = 2;
+}
+
+message Rank {
+ int32 rank = 1;
+}
+
+message Port {
+ int32 port = 1;
+}
+
+message PeerId {
+ Rank rank = 1;
+ Port port = 2;
+ string ip = 3;
+}
+
+message PeerIds {
+  map<int32, PeerId> peer_ids = 1;
+}
+
+message Quorum {
+ bool quorum = 1;
+}
\ No newline at end of file
diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py
new file mode 100644
index 0000000..5c1623f
--- /dev/null
+++ b/src/utils/communication/grpc/comm_pb2.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: comm.proto
+# Protobuf Python Version: 5.26.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xb9\x01\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'comm_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+ DESCRIPTOR._loaded_options = None
+ _globals['_PEERIDS_PEERIDSENTRY']._loaded_options = None
+ _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001'
+ _globals['_EMPTY']._serialized_start=14
+ _globals['_EMPTY']._serialized_end=21
+ _globals['_MODEL']._serialized_start=23
+ _globals['_MODEL']._serialized_end=46
+ _globals['_DATA']._serialized_start=48
+ _globals['_DATA']._serialized_end=89
+ _globals['_RANK']._serialized_start=91
+ _globals['_RANK']._serialized_end=111
+ _globals['_PORT']._serialized_start=113
+ _globals['_PORT']._serialized_end=133
+ _globals['_PEERID']._serialized_start=135
+ _globals['_PEERID']._serialized_end=197
+ _globals['_PEERIDS']._serialized_start=199
+ _globals['_PEERIDS']._serialized_end=306
+ _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=251
+ _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=306
+ _globals['_QUORUM']._serialized_start=308
+ _globals['_QUORUM']._serialized_end=332
+ _globals['_COMMUNICATIONSERVER']._serialized_start=335
+ _globals['_COMMUNICATIONSERVER']._serialized_end=520
+# @@protoc_insertion_point(module_scope)
diff --git a/src/utils/communication/grpc/comm_pb2.pyi b/src/utils/communication/grpc/comm_pb2.pyi
new file mode 100644
index 0000000..e81bc4e
--- /dev/null
+++ b/src/utils/communication/grpc/comm_pb2.pyi
@@ -0,0 +1,65 @@
+from google.protobuf.internal import containers as _containers
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class Empty(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+class Model(_message.Message):
+ __slots__ = ("buffer",)
+ BUFFER_FIELD_NUMBER: _ClassVar[int]
+ buffer: bytes
+ def __init__(self, buffer: _Optional[bytes] = ...) -> None: ...
+
+class Data(_message.Message):
+ __slots__ = ("id", "model")
+ ID_FIELD_NUMBER: _ClassVar[int]
+ MODEL_FIELD_NUMBER: _ClassVar[int]
+ id: str
+ model: Model
+ def __init__(self, id: _Optional[str] = ..., model: _Optional[_Union[Model, _Mapping]] = ...) -> None: ...
+
+class Rank(_message.Message):
+ __slots__ = ("rank",)
+ RANK_FIELD_NUMBER: _ClassVar[int]
+ rank: int
+ def __init__(self, rank: _Optional[int] = ...) -> None: ...
+
+class Port(_message.Message):
+ __slots__ = ("port",)
+ PORT_FIELD_NUMBER: _ClassVar[int]
+ port: int
+ def __init__(self, port: _Optional[int] = ...) -> None: ...
+
+class PeerId(_message.Message):
+ __slots__ = ("rank", "port", "ip")
+ RANK_FIELD_NUMBER: _ClassVar[int]
+ PORT_FIELD_NUMBER: _ClassVar[int]
+ IP_FIELD_NUMBER: _ClassVar[int]
+ rank: Rank
+ port: Port
+ ip: str
+ def __init__(self, rank: _Optional[_Union[Rank, _Mapping]] = ..., port: _Optional[_Union[Port, _Mapping]] = ..., ip: _Optional[str] = ...) -> None: ...
+
+class PeerIds(_message.Message):
+ __slots__ = ("peer_ids",)
+ class PeerIdsEntry(_message.Message):
+ __slots__ = ("key", "value")
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: int
+ value: PeerId
+ def __init__(self, key: _Optional[int] = ..., value: _Optional[_Union[PeerId, _Mapping]] = ...) -> None: ...
+ PEER_IDS_FIELD_NUMBER: _ClassVar[int]
+ peer_ids: _containers.MessageMap[int, PeerId]
+ def __init__(self, peer_ids: _Optional[_Mapping[int, PeerId]] = ...) -> None: ...
+
+class Quorum(_message.Message):
+ __slots__ = ("quorum",)
+ QUORUM_FIELD_NUMBER: _ClassVar[int]
+ quorum: bool
+ def __init__(self, quorum: bool = ...) -> None: ...
diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py
new file mode 100644
index 0000000..feb7eb8
--- /dev/null
+++ b/src/utils/communication/grpc/comm_pb2_grpc.py
@@ -0,0 +1,274 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+import warnings
+
+import comm_pb2 as comm__pb2
+
+GRPC_GENERATED_VERSION = '1.64.0'
+GRPC_VERSION = grpc.__version__
+EXPECTED_ERROR_RELEASE = '1.65.0'
+SCHEDULED_RELEASE_DATE = 'June 25, 2024'
+_version_not_supported = False
+
+try:
+ from grpc._utilities import first_version_is_lower
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+except ImportError:
+ _version_not_supported = True
+
+if _version_not_supported:
+ warnings.warn(
+ f'The grpc package installed is at version {GRPC_VERSION},'
+ + f' but the generated code in comm_pb2_grpc.py depends on'
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+ + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+ + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+ RuntimeWarning
+ )
+
+
+class CommunicationServerStub(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.send_data = channel.unary_unary(
+ '/CommunicationServer/send_data',
+ request_serializer=comm__pb2.Data.SerializeToString,
+ response_deserializer=comm__pb2.Empty.FromString,
+ _registered_method=True)
+ self.get_rank = channel.unary_unary(
+ '/CommunicationServer/get_rank',
+ request_serializer=comm__pb2.Empty.SerializeToString,
+ response_deserializer=comm__pb2.Rank.FromString,
+ _registered_method=True)
+ self.update_port = channel.unary_unary(
+ '/CommunicationServer/update_port',
+ request_serializer=comm__pb2.PeerId.SerializeToString,
+ response_deserializer=comm__pb2.Empty.FromString,
+ _registered_method=True)
+ self.send_peer_ids = channel.unary_unary(
+ '/CommunicationServer/send_peer_ids',
+ request_serializer=comm__pb2.PeerIds.SerializeToString,
+ response_deserializer=comm__pb2.Empty.FromString,
+ _registered_method=True)
+ self.send_quorum = channel.unary_unary(
+ '/CommunicationServer/send_quorum',
+ request_serializer=comm__pb2.Quorum.SerializeToString,
+ response_deserializer=comm__pb2.Empty.FromString,
+ _registered_method=True)
+
+
+class CommunicationServerServicer(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def send_data(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def get_rank(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def update_port(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def send_peer_ids(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def send_quorum(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_CommunicationServerServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'send_data': grpc.unary_unary_rpc_method_handler(
+ servicer.send_data,
+ request_deserializer=comm__pb2.Data.FromString,
+ response_serializer=comm__pb2.Empty.SerializeToString,
+ ),
+ 'get_rank': grpc.unary_unary_rpc_method_handler(
+ servicer.get_rank,
+ request_deserializer=comm__pb2.Empty.FromString,
+ response_serializer=comm__pb2.Rank.SerializeToString,
+ ),
+ 'update_port': grpc.unary_unary_rpc_method_handler(
+ servicer.update_port,
+ request_deserializer=comm__pb2.PeerId.FromString,
+ response_serializer=comm__pb2.Empty.SerializeToString,
+ ),
+ 'send_peer_ids': grpc.unary_unary_rpc_method_handler(
+ servicer.send_peer_ids,
+ request_deserializer=comm__pb2.PeerIds.FromString,
+ response_serializer=comm__pb2.Empty.SerializeToString,
+ ),
+ 'send_quorum': grpc.unary_unary_rpc_method_handler(
+ servicer.send_quorum,
+ request_deserializer=comm__pb2.Quorum.FromString,
+ response_serializer=comm__pb2.Empty.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'CommunicationServer', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+ server.add_registered_method_handlers('CommunicationServer', rpc_method_handlers)
+
+
+ # This class is part of an EXPERIMENTAL API.
+class CommunicationServer(object):
+ """Missing associated documentation comment in .proto file."""
+
+ @staticmethod
+ def send_data(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/CommunicationServer/send_data',
+ comm__pb2.Data.SerializeToString,
+ comm__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
+
+ @staticmethod
+ def get_rank(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/CommunicationServer/get_rank',
+ comm__pb2.Empty.SerializeToString,
+ comm__pb2.Rank.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
+
+ @staticmethod
+ def update_port(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/CommunicationServer/update_port',
+ comm__pb2.PeerId.SerializeToString,
+ comm__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
+
+ @staticmethod
+ def send_peer_ids(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/CommunicationServer/send_peer_ids',
+ comm__pb2.PeerIds.SerializeToString,
+ comm__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
+
+ @staticmethod
+ def send_quorum(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/CommunicationServer/send_quorum',
+ comm__pb2.Quorum.SerializeToString,
+ comm__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
diff --git a/src/utils/communication/grpc/grpc_utils.py b/src/utils/communication/grpc/grpc_utils.py
new file mode 100644
index 0000000..64d8e6b
--- /dev/null
+++ b/src/utils/communication/grpc/grpc_utils.py
@@ -0,0 +1,18 @@
+from collections import OrderedDict
+import io
+import torch
+
+def serialize_model(state_dict: OrderedDict[str, torch.Tensor]) -> bytes:
+    # put every parameter on cpu first (note: this mutates the caller's state_dict in place)
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].to('cpu')
+ buffer = io.BytesIO()
+ torch.save(state_dict, buffer) # type: ignore
+ buffer.seek(0)
+ return buffer.read()
+
+def deserialize_model(model_bytes: bytes) -> OrderedDict[str, torch.Tensor]:
+ buffer = io.BytesIO(model_bytes)
+ buffer.seek(0)
+ model_wts = torch.load(buffer) # type: ignore
+ return model_wts
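+
+# Round-trip sketch (assumes `net` is any torch.nn.Module):
+#   wts = net.state_dict()
+#   restored = deserialize_model(serialize_model(wts))
+#   assert list(restored.keys()) == list(wts.keys())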
diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py
new file mode 100644
index 0000000..82695b1
--- /dev/null
+++ b/src/utils/communication/grpc/main.py
@@ -0,0 +1,282 @@
+from concurrent import futures
+from queue import Queue
+import random
+import re
+import threading
+import time
+import socket
+from typing import Any, Dict, List, OrderedDict, Union
+from urllib.parse import unquote
+import grpc # type: ignore
+from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model
+import os
+import sys
+
+grpc_generated_dir = os.path.dirname(os.path.abspath(__file__))
+if grpc_generated_dir not in sys.path:
+ sys.path.append(grpc_generated_dir)
+
+import comm_pb2 as comm_pb2
+import comm_pb2_grpc as comm_pb2_grpc
+from utils.communication.interface import CommunicationInterface
+
+
+# TODO: Several changes needed to improve the quality of the code
+# 1. We need to improve comm.proto and get rid of singletons like Rank, Port etc.
+# 2. Some parts of the code are heavily nested and need to be refactored
+# 3. Insert try-except blocks wherever communication is involved
+# 4. Probably a good idea to move the Servicer class to a separate file
+# 5. Not needed for benchmarking but for the system to be robust, we need to implement timeouts and fault tolerance
+# 6. Peer_ids should be indexed by a unique identifier
+# 7. Try to get rid of type: ignore as much as possible
+
+def is_port_available(port: int) -> bool:
+ """
+ Check if a port is available for use.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # type: ignore
+ return s.connect_ex(('localhost', port)) != 0 # type: ignore
+
+def get_port(rank: int, num_users: int) -> int:
+ """
+ Get the port number for the given rank.
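+    e.g. with num_users=10, rank 3 probes 50054, then 50064, 50074, ... until one is free.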
+ """
+ start = 50051
+ while True:
+ port = start + rank
+ if is_port_available(port):
+ return port
+        # increment by num_users rather than 1 so that a retry does not
+        # collide with the port that another node is probing next
+        start += num_users
+ if start > 65535:
+ raise Exception(f"No available ports for node {rank}")
+
+def parse_peer_address(peer: str) -> str:
+ # Remove 'ipv4:' or 'ipv6:' prefix
+ if peer.startswith(('ipv4:', 'ipv6:')):
+ peer = peer.split(':', 1)[1]
+
+ # Handle IPv6 address
+ if peer.startswith('['):
+ # Extract IPv6 address
+ match = re.match(r'\[([^\]]+)\]', peer)
+ if match:
+ return unquote(match.group(1)) # Decode URL-encoded characters
+ else:
+ # Handle IPv4 address or hostname
+ return peer.rsplit(':', 1)[0] # Remove port number
+ return ""
+
+class Servicer(comm_pb2_grpc.CommunicationServerServicer):
+ def __init__(self, super_node_host: str):
+ self.lock = threading.Lock()
+ self.condition = threading.Condition(self.lock)
+ self.received_data: Queue[Any] = Queue()
+ self.quorum: Queue[bool] = Queue()
+ port = int(super_node_host.split(":")[1])
+ ip = super_node_host.split(":")[0]
+ self.peer_ids: OrderedDict[int, Dict[str, int|str]] = OrderedDict({
+ 0: {"rank": 0, "port": port, "ip": ip}
+ })
+
+ def send_data(self, request, context) -> comm_pb2.Empty: # type: ignore
+ self.received_data.put(deserialize_model(request.model.buffer)) # type: ignore
+ return comm_pb2.Empty() # type: ignore
+
+ def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank: # type: ignore
+ try:
+ with self.lock:
+ peer = context.peer() # type: ignore
+ # parse the hostname from peer
+ peer_str = parse_peer_address(peer) # type: ignore
+ rank = len(self.peer_ids)
+ # TODO: index the peer_ids by a unique identifier
+ self.peer_ids[rank] = {"rank": rank, "port": 0, "ip": peer_str}
+ rank = self.peer_ids[rank].get("rank", -1) # Default to -1 if not found
+ return comm_pb2.Rank(rank=rank) # type: ignore
+ except Exception as e:
+ context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore
+
+    def update_port(self, request: comm_pb2.PeerId, context: grpc.ServicerContext) -> comm_pb2.Empty:
+ with self.lock:
+ self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore
+ return comm_pb2.Empty() # type: ignore
+
+ def send_peer_ids(self, request, context) -> comm_pb2.Empty: # type: ignore
+ """
+ Used by the super node to update all peers with the peer_ids
+ after achieving quorum.
+ """
+ peer_ids: comm_pb2.PeerIds = request.peer_ids # type: ignore
+ for rank in peer_ids: # type: ignore
+ peer_id_proto = peer_ids[rank] # type: ignore
+ peer_id_dict: Dict[str, Union[int, str]] = {
+ "rank": peer_id_proto.rank.rank, # type: ignore
+ "port": peer_id_proto.port.port, # type: ignore
+ "ip": peer_id_proto.ip # type: ignore
+ }
+ self.peer_ids[rank] = peer_id_dict
+ return comm_pb2.Empty()
+
+ def send_quorum(self, request, context) -> comm_pb2.Empty: # type: ignore
+ self.quorum.put(request.quorum) # type: ignore
+ return comm_pb2.Empty() # type: ignore
+
+
+class GRPCCommunication(CommunicationInterface):
+    def __init__(self, config: Dict[str, Any]):
+ # TODO: Implement this differently later by creating a super node
+ # that maintains a list of all peers
+ # all peers will send their IDs to this super node
+ # when they start.
+ # The implementation will have
+ # 1. Registration phase where every peer registers itself and gets a rank
+ # 2. Once a threshold number of peers have registered, the super node sets quorum to True
+ # 3. The super node broadcasts the peer_ids to all peers
+ # 4. The nodes will execute rest of the protocol in the same way as before
+ self.num_users: int = int(config["num_users"]) # type: ignore
+ self.rank: int|None = config["comm"]["rank"]
+ self.super_node_host: str = config["comm"]["peer_ids"][0]
+ if self.rank == 0:
+ node_id: List[str] = self.super_node_host.split(":")
+ self.host: str = node_id[0]
+ self.port: int = int(node_id[1])
+ self.listener: Any = None
+ self.servicer = Servicer(self.super_node_host)
+
+ @staticmethod
+ def get_registered_users(peer_ids: OrderedDict[int, Dict[str, int|str]]) -> int:
+ # count the number of entries that have a non-zero port
+ return len([peer_id for peer_id, values in peer_ids.items() if values.get("port") != 0])
+
+ def register(self):
+ with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore
+ stub = comm_pb2_grpc.CommunicationServerStub(channel)
+ max_tries = 10
+ while max_tries > 0:
+ try:
+ self.rank = stub.get_rank(comm_pb2.Empty()).rank # type: ignore
+ break
+ except grpc.RpcError as e:
+ print(f"RPC failed: {e}", "Retrying...")
+ # sleep for a random time between 1 and 10 seconds
+ random_time = random.randint(1, 10)
+ time.sleep(random_time)
+                    max_tries -= 1
+            else:
+                # while-else: only reached if we never got a rank from the super-node
+                raise Exception("Could not register with the super-node after 10 attempts")
+            # keep these inside the `with` block: the channel must still be open for update_port
+            self.port = get_port(self.rank, self.num_users) # type: ignore (rank is guaranteed once registration succeeds)
+            rank = comm_pb2.Rank(rank=self.rank) # type: ignore
+            port = comm_pb2.Port(port=self.port)
+            peer_id = comm_pb2.PeerId(rank=rank, port=port)
+            stub.update_port(peer_id) # type: ignore
+
+ def start_listener(self):
+ self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=4), options=[ # type: ignore
+ ('grpc.max_send_message_length', 100 * 1024 * 1024), # 100MB
+ ('grpc.max_receive_message_length', 100 * 1024 * 1024), # 100MB
+ ])
+ comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener) # type: ignore
+ self.listener.add_insecure_port(f'[::]:{self.port}')
+ self.listener.start()
+ print(f'Started listener on port {self.port}')
+
+ def peer_ids_to_proto(self, peer_ids: OrderedDict[int, Dict[str, int|str]]) -> Dict[int, comm_pb2.PeerId]:
+ peer_ids_proto: Dict[int, comm_pb2.PeerId] = {}
+ for peer_id in peer_ids:
+ rank = comm_pb2.Rank(rank=peer_ids[peer_id].get("rank")) # type: ignore
+ port = comm_pb2.Port(port=peer_ids[peer_id].get("port")) # type: ignore
+ ip = str(peer_ids[peer_id].get("ip"))
+ peer_ids_proto[peer_id] = comm_pb2.PeerId(rank=rank, port=port, ip=ip)
+ return peer_ids_proto
+
+ def initialize(self):
+ if self.rank != 0:
+ self.register()
+
+ self.start_listener()
+
+ # wait for the quorum to be set
+ if self.rank != 0:
+ status = self.servicer.quorum.get()
+ if not status:
+ print("Quorum became false!")
+ sys.exit(1)
+ else:
+ quorum_threshold = self.num_users + 1 # +1 for the super node
+ num_registered = self.get_registered_users(self.servicer.peer_ids)
+ while num_registered < quorum_threshold:
+ # sleep for 5 seconds
+ print(f"Waiting for {quorum_threshold} users to register, {num_registered} have registered so far")
+ time.sleep(5)
+ num_registered = self.get_registered_users(self.servicer.peer_ids)
+ # TODO: Implement a timeout here and if the timeout is reached
+ # then set quorum to False for all registered users
+ # and exit the program
+
+ print("All users have registered", self.servicer.peer_ids)
+ for peer_id in self.servicer.peer_ids:
+ host_ip = self.servicer.peer_ids[peer_id].get("ip")
+ own_ip = self.servicer.peer_ids[0].get("ip")
+ if host_ip != own_ip:
+ port = self.servicer.peer_ids[peer_id].get("port")
+ address = f"{host_ip}:{port}"
+ with grpc.insecure_channel(address) as channel: # type: ignore
+ stub = comm_pb2_grpc.CommunicationServerStub(channel)
+ proto_msg = comm_pb2.PeerIds(peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids))
+ stub.send_peer_ids(proto_msg) # type: ignore
+ stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore
+
+ def get_host_from_rank(self, rank: int) -> str:
+ for peer_id in self.servicer.peer_ids:
+ if self.servicer.peer_ids[peer_id].get("rank") == rank:
+ return self.servicer.peer_ids[peer_id].get("ip") + ":" + str(self.servicer.peer_ids[peer_id].get("port")) # type: ignore
+ raise Exception(f"Rank {rank} not found in peer_ids")
+
+ def send(self, dest: str|int, data: OrderedDict[str, Any]):
+ """
+ data should be a torch model
+ """
+ dest_host: str = ""
+        if isinstance(dest, int):
+ dest_host = self.get_host_from_rank(dest)
+ else:
+ dest_host = str(dest)
+ try:
+ buffer = serialize_model(data)
+ with grpc.insecure_channel(dest_host) as channel: # type: ignore
+ stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore
+ model = comm_pb2.Model(buffer=buffer) # type: ignore
+ stub.send_data(comm_pb2.Data(model=model, id=str(self.rank))) # type: ignore
+ except grpc.RpcError as e:
+ print(f"RPC failed: {e}")
+ sys.exit(1)
+
+ def receive(self, node_ids: str|int) -> Any:
+ # this .get() will block until
+ # at least 1 item is received in the queue
+ return self.servicer.received_data.get()
+
+    def is_own_id(self, peer_id: int) -> bool:
+        rank = self.servicer.peer_ids[peer_id].get("rank")
+        return rank == self.rank
+
+ def broadcast(self, data: Any):
+ for peer_id in self.servicer.peer_ids:
+ if not self.is_own_id(peer_id):
+ self.send(peer_id, data)
+
+ def all_gather(self) -> Any:
+ # this will block until all items are received
+ # from all peers
+ items: List[Any] = []
+ for peer_id in self.servicer.peer_ids:
+ if not self.is_own_id(peer_id):
+ items.append(self.receive(peer_id))
+ return items
+
+ def finalize(self):
+ if self.listener:
+ self.listener.stop(0)
+ print(f'Stopped server on port {self.port}')
diff --git a/src/utils/communication/interface.py b/src/utils/communication/interface.py
new file mode 100644
index 0000000..b5ab2ba
--- /dev/null
+++ b/src/utils/communication/interface.py
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+class CommunicationInterface(ABC):
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def initialize(self):
+ pass
+
+ @abstractmethod
+ def send(self, dest: str|int, data: Any):
+ pass
+
+ @abstractmethod
+ def receive(self, node_ids: str|int) -> Any:
+ pass
+
+ @abstractmethod
+ def broadcast(self, data: Any):
+ pass
+
+ @abstractmethod
+ def all_gather(self) -> Any:
+ pass
+
+ @abstractmethod
+ def finalize(self):
+ pass
diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py
new file mode 100644
index 0000000..ea5bad8
--- /dev/null
+++ b/src/utils/communication/mpi.py
@@ -0,0 +1,35 @@
+from typing import Dict, Any, List
+from mpi4py import MPI
+from utils.communication.interface import CommunicationInterface
+
+class MPICommUtils(CommunicationInterface):
+    def __init__(self, config: Dict[str, Any]):
+ self.comm = MPI.COMM_WORLD
+ self.rank = self.comm.Get_rank()
+ self.size = self.comm.Get_size()
+
+ def initialize(self):
+ pass
+
+ def send(self, dest: str|int, data: Any):
+ self.comm.send(data, dest=int(dest))
+
+ def receive(self, node_ids: str|int) -> Any:
+ return self.comm.recv(source=int(node_ids))
+
+ def broadcast(self, data: Any):
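+        # point-to-point fan-out: sends to every rank except self and rank 0;
+        # in practice rank 0 (the server) is the caller, so every client receives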
+ for i in range(1, self.size):
+ if i != self.rank:
+ self.send(i, data)
+
+ def all_gather(self):
+ """
+ This function is used to gather data from all the nodes.
+ """
+ items: List[Any] = []
+ for i in range(1, self.size):
+ items.append(self.receive(i))
+ return items
+
+ def finalize(self):
+ pass
diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py
index b3cbc8a..51407d9 100644
--- a/src/utils/config_utils.py
+++ b/src/utils/config_utils.py
@@ -1,8 +1,9 @@
+from typing import Any, Dict
import jmespath
import importlib
-def load_config(config_path):
+def load_config(config_path: str) -> Dict[str, Any]:
path = ".".join(config_path.split(".")[1].split("/")[1:])
config = importlib.import_module(path).current_config
return config
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index d3ea63b..4fbe2ef 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -1,4 +1,5 @@
import importlib
+from typing import Tuple
import numpy as np
import torch
import torchvision.transforms as T
@@ -93,7 +94,7 @@ def filter_by_class(dataset, classes):
return Subset(dataset, indices), indices
-def random_samples(dataset, num_samples):
+def random_samples(dataset, num_samples) -> Tuple[Subset, np.ndarray]:
"""
Returns a random subset of samples from the dataset.
"""
@@ -110,7 +111,7 @@ def extr_noniid(train_dataset, samples_per_client, classes):
return Subset(all_data, perm[:samples_per_client])
-def cifar_extr_noniid(
+def cifar_extr_noniid(
train_dataset, test_dataset, num_users, n_class, num_samples, rate_unbalance
):
"""
diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py
index 3dd6612..96afea2 100644
--- a/src/utils/log_utils.py
+++ b/src/utils/log_utils.py
@@ -9,11 +9,11 @@
import sys
from glob import glob
from shutil import copytree, copy2
-from PIL import Image
+from typing import Any, Dict
import torch
-import torchvision.transforms as T
-from torchvision.utils import make_grid, save_image
-from tensorboardX import SummaryWriter
+import torchvision.transforms as T # type: ignore
+from torchvision.utils import make_grid, save_image # type: ignore
+from tensorboardX import SummaryWriter # type: ignore
import numpy as np
@@ -36,7 +36,7 @@ def deprocess(img):
return img.type(torch.uint8)
-def check_and_create_path(path, folder_deletion_path):
+def check_and_create_path(path: str, folder_deletion_path: str|None=None):
"""
-Checks if the specified path exists and prompts the user for action if it does.
+Checks if the specified path exists and exits with a message if it does.
Creates the directory if it does not exist.
@@ -45,26 +45,11 @@ def check_and_create_path(path, folder_deletion_path):
path (str): Path to check and create if necessary.
"""
if os.path.isdir(path):
- print(f"Experiment in {path} already present")
- done = False
- while not done:
- # Color the input prompt
- color_code = "\033[94m" # Blue text
- reset_code = "\033[0m" # Reset to default color
- # Highlighted prompt in blue
- inp = input(f"{color_code}Press e to exit, r to replace it: {reset_code}")
-
- if inp == "e":
- sys.exit()
- elif inp == "r":
- done = True
- shutil.rmtree(path)
- os.makedirs(path)
- with open(folder_deletion_path, "w") as signal_file:
- #Folder deletion complete signal.
- signal_file.write("r")
- else:
- print("Input not understood")
+ color_code = "\033[94m" # Blue text
+ reset_code = "\033[0m" # Reset to default color
+ print(f"{color_code}Experiment in {path} already present. Exiting.")
+ print(f"Please do: rm -rf {path} to delete the folder.{reset_code}")
+ sys.exit()
else:
os.makedirs(path)
with open(folder_deletion_path, "w") as signal_file:
@@ -72,7 +57,7 @@ def check_and_create_path(path, folder_deletion_path):
signal_file.write("new")
-def copy_source_code(config: dict) -> None:
+def copy_source_code(config: Dict[str, Any]) -> None:
"""
Copy source code to experiment folder for reproducibility.
@@ -120,7 +105,7 @@ class LogUtils:
"""
Utility class for logging and saving experiment data.
"""
- def __init__(self, config) -> None:
+ def __init__(self, config: Dict[str, Any]) -> None:
log_dir = config["log_path"]
load_existing = config["load_existing"]
log_format = (
@@ -144,7 +129,7 @@ def init_summary(self):
"""
self.summary_file = open(f"{self.log_dir}/summary.txt", "w", encoding="utf-8")
- def init_tb(self, load_existing):
+ def init_tb(self, load_existing: bool):
"""
Initialize TensorBoard logging.
@@ -164,7 +149,7 @@ def init_npy(self):
if not os.path.exists(npy_path) or not os.path.isdir(npy_path):
os.makedirs(npy_path)
- def log_summary(self, text):
+ def log_summary(self, text: str):
"""
Add summary text to the summary file for logging.
"""
@@ -174,7 +159,7 @@ def log_summary(self, text):
else:
raise ValueError("Summary file is not initialized. Call init_summary() first.")
- def log_image(self, imgs: torch.Tensor, key, iteration):
+ def log_image(self, imgs: torch.Tensor, key: str, iteration: int):
"""
Log image to file and TensorBoard.
@@ -187,7 +172,7 @@ def log_image(self, imgs: torch.Tensor, key, iteration):
save_image(grid_img, f"{self.log_dir}/{iteration}_{key}.png")
self.writer.add_image(key, grid_img.numpy(), iteration)
- def log_console(self, msg):
+ def log_console(self, msg: str):
"""
Log a message to the console.
@@ -196,7 +181,7 @@ def log_console(self, msg):
"""
logging.info(msg)
- def log_tb(self, key, value, iteration):
+ def log_tb(self, key: str, value: float|int, iteration: int):
"""
Log a scalar value to TensorBoard.
@@ -205,9 +190,9 @@ def log_tb(self, key, value, iteration):
value (float): Value to log.
iteration (int): Current iteration number.
"""
- self.writer.add_scalar(key, value, iteration)
+ self.writer.add_scalar(key, value, iteration) # type: ignore
- def log_npy(self, key, value):
+ def log_npy(self, key: str, value: np.ndarray):
"""
Save a numpy array to file.
@@ -217,7 +202,7 @@ def log_npy(self, key, value):
"""
np.save(f"{self.log_dir}/npy/{key}.npy", value)
- def log_max_stats_per_client(self, stats_per_client, round_step, metric):
+ def log_max_stats_per_client(self, stats_per_client: np.ndarray, round_step: int, metric: str):
"""
Log maximum statistics per client.
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index 8441872..14ea93b 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -1,9 +1,10 @@
from collections import OrderedDict
-from typing import Tuple, List
+from typing import Any, Tuple, List
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
+from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
import resnet
@@ -13,7 +14,8 @@
class ModelUtils():
- def __init__(self) -> None:
+ def __init__(self, device: torch.device) -> None:
+ self.device = device
self.models_layers_idx = {
"resnet10": {
@@ -33,14 +35,14 @@ def __init__(self) -> None:
}
def get_model(
- self,
- model_name: str,
- dset: str,
- device: torch.device,
- device_ids: list,
- pretrained=False,
- **kwargs) -> DataParallel:
- self.dset = dset # cifar10, wilds, pascal
+ self,
+ model_name: str,
+ dset: str,
+ device: torch.device,
+ device_ids: List[int],
+ pretrained:bool=False,
+ **kwargs: Any
+ ) -> nn.Module:
# TODO: add support for loading checkpointed models
model_name = model_name.lower()
if model_name == "resnet10":
@@ -65,12 +67,12 @@ def get_model(
def train(self,
model: nn.Module,
- optim,
- dloader,
- loss_fn,
+ optim: torch.optim.Optimizer,
+ dloader: DataLoader[Any],
+ loss_fn: Any,
device: torch.device,
- test_loader=None,
- **kwargs) -> Tuple[float,
+ test_loader: DataLoader[Any]|None=None,
+ **kwargs: Any) -> Tuple[float,
float]:
"""TODO: generate docstring
"""
@@ -291,8 +293,8 @@ def deep_mutual_train(self, models, optim, dloader,
train_accuracy[i] = 100. * correct[i] / len(dloader.dataset)
return train_loss, train_accuracy
- def test(self, model, dloader, loss_fn, device,
- **kwargs) -> Tuple[float, float]:
+ def test(self, model: nn.Module, dloader: DataLoader[Any], loss_fn: Any, device: torch.device,
+ **kwargs: Any) -> Tuple[float, float]:
"""TODO: generate docstring
"""
model.eval()
@@ -362,7 +364,7 @@ def test_classification(self, model, dloader, loss_fn, device, **kwargs) -> Tupl
acc = correct / len(dloader.dataset)
return test_loss, acc
- def save_model(self, model, path):
+ def save_model(self, model: nn.Module, path: str) -> None:
if isinstance(model, DataParallel):
model_ = model.module
else:
@@ -403,3 +405,9 @@ def filter_model_weights(
if key not in key_to_ignore:
filtered_model_wts[key] = param
return filtered_model_wts
+
+ def get_memory_usage(self):
+ """
+        Get the memory (in bytes) that torch has allocated on self.device (CUDA devices only)
+ """
+ return torch.cuda.memory_allocated(self.device)