diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9e6483a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "strict" +} \ No newline at end of file diff --git a/src/README.md b/src/README.md index 0e66197..49136fe 100644 --- a/src/README.md +++ b/src/README.md @@ -31,3 +31,8 @@ Client Metrics Server Metrics + +### Debugging instructions +GRPC simulation starts a lot of threads and even if one of them fail right now then you will have to kill all of them and start all over. +So, here is a command to get the pid of all the threads and kill them all at once: +`for pid in $(ps aux|grep 'python main.py -r' | cut -b 10-16); do kill -9 $pid; done` \ No newline at end of file diff --git a/src/algos/base_class.py b/src/algos/base_class.py index ba83a71..049d8c9 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -4,14 +4,14 @@ from torch.utils.data import DataLoader, Subset from collections import OrderedDict -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from torch import Tensor import copy import random import numpy as np +from utils.communication.comm_utils import CommunicationManager from utils.plot_utils import PlotUtils -from utils.comm_utils import CommUtils from utils.data_utils import ( random_samples, filter_by_class, @@ -28,15 +28,15 @@ get_dset_balanced_communities, get_dset_communities, ) -import torchvision.transforms as T +import torchvision.transforms as T # type: ignore import os from yolo import YOLOLoss class BaseNode(ABC): - def __init__(self, config) -> None: - self.comm_utils = CommUtils() - self.node_id = self.comm_utils.rank + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + self.comm_utils = comm_utils + self.node_id = self.comm_utils.get_rank() if self.node_id == 0: self.log_dir = config['log_path'] @@ -54,13 +54,13 @@ def __init__(self, config) -> None: if isinstance(config["dset"], dict): if self.node_id != 0: config["dset"].pop("0") - self.dset = config["dset"][str(self.node_id)] + self.dset = str(config["dset"][str(self.node_id)]) config["dpath"] = config["dpath"][self.dset] else: self.dset = config["dset"] self.setup_cuda(config) - self.model_utils = ModelUtils() + self.model_utils = ModelUtils(self.device) self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) self.set_constants() @@ -68,7 +68,7 @@ def __init__(self, config) -> None: def set_constants(self): self.best_acc = 0.0 - def setup_cuda(self, config): + def setup_cuda(self, config: Dict[str, Any]): # Need a mapping from rank to device id device_ids_map = config["device_ids"] node_name = "node_{}".format(self.node_id) @@ -82,7 +82,7 @@ def setup_cuda(self, config): self.device = torch.device("cpu") print("Using CPU") - def set_model_parameters(self, config): + def set_model_parameters(self, config: Dict[str, Any]): # Model related parameters optim_name = config.get("optimizer", "adam") if optim_name == "adam": @@ -149,7 +149,7 @@ def set_shared_exp_parameters(self, config): self.log_utils.log_console("Communities: {}".format(self.communities)) @abstractmethod - def run_protocol(self): + def run_protocol(self) -> None: raise NotImplementedError @@ -158,8 +158,8 @@ class BaseClient(BaseNode): Abstract class for all algorithms """ - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config, comm_utils) -> None: + super().__init__(config, comm_utils) self.server_node = 0 self.set_parameters(config) @@ -215,8 +215,8 @@ def set_data_parameters(self, config): train_dset = self.dset_obj.train_dset test_dset = self.dset_obj.test_dset - print("num train", len(train_dset)) - print("num test", len(test_dset)) + # print("num train", len(train_dset)) + # print("num test", len(test_dset)) if config.get("test_samples_per_class", None) is not None: test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"]) @@ -369,19 +369,19 @@ def is_same_dest(dset): # TODO: fix print_data_summary # self.print_data_summary(train_dset, test_dset, val_dset=val_dset) - def local_train(self, dataset, **kwargs): + def local_train(self, round: int, **kwargs: Any) -> None: """ Train the model locally """ raise NotImplementedError - def local_test(self, dataset, **kwargs): + def local_test(self, **kwargs: Any) -> float | Tuple[float, float] | None: """ Test the model locally """ raise NotImplementedError - def get_representation(self, **kwargs): + def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor: """ Share the model representation """ @@ -416,21 +416,17 @@ def print_data_summary(self, train_test, test_dset, val_dset=None): print("test count: ", i) i += 1 - print("Node: {} data distribution summary".format(self.node_id)) - print(type(train_sample_per_class.items())) - print( - "Train samples per class: {}".format(sorted(train_sample_per_class.items())) - ) - print( - "Train samples per class: {}".format(len(train_sample_per_class.items())) - ) - if val_dset is not None: - print( - "Val samples per class: {}".format(len(val_sample_per_class.items())) - ) - print( - "Test samples per class: {}".format(len(test_sample_per_class.items())) - ) + # print("Node: {} data distribution summary".format(self.node_id)) + # print( + # "Train samples per class: {}".format(sorted(train_sample_per_class.items())) + # ) + # if val_dset is not None: + # print( + # "Val samples per class: {}".format(sorted(val_sample_per_class.items())) + # ) + # print( + # "Test samples per class: {}".format(sorted(test_sample_per_class.items())) + # ) class BaseServer(BaseNode): @@ -438,8 +434,8 @@ class BaseServer(BaseNode): Abstract class for orchestrator """ - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config, comm_utils) -> None: + super().__init__(config, comm_utils) self.num_users = config["num_users"] self.users = list(range(1, self.num_users + 1)) self.set_data_parameters(config) @@ -449,13 +445,13 @@ def set_data_parameters(self, config): batch_size = config["batch_size"] self._test_loader = DataLoader(test_dset, batch_size=batch_size) - def aggregate(self, representation_list, **kwargs): + def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: """ Aggregate the knowledge from the users """ raise NotImplementedError - def test(self, dataset, **kwargs): + def test(self, **kwargs: Any) -> List[float]: """ Test the model on the server """ @@ -668,10 +664,5 @@ def __init__(self, config, comm_protocol=CommProtocol) -> None: super().__init__(config) self.tag = comm_protocol - def send_representations(self, representations, tag=None): - for user_node in self.users: - self.comm_utils.send_signal( - dest=user_node, - data=representations, - tag=self.tag.REPRS_SHARE if tag is None else tag, - ) + def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]): + self.comm_utils.broadcast(representations) diff --git a/src/algos/fl.py b/src/algos/fl.py index 781c417..94eb368 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -1,48 +1,32 @@ from collections import OrderedDict +import sys from typing import Any, Dict, List from torch import Tensor -import torch.nn as nn +from utils.communication.comm_utils import CommunicationManager from utils.log_utils import LogUtils from algos.base_class import BaseClient, BaseServer import os import time -class CommProtocol(object): - """ - Communication protocol tags for the server and clients - """ - - DONE = 0 # Used to signal that the client is done with the current round - START = 1 # Used to signal by the server to start the current round - UPDATES = 2 # Used to send the updates from the server to the clients - - class FedAvgClient(BaseClient): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) self.config = config - self.tag = CommProtocol - self.folder_deletion_signal = config["folder_deletion_signal_path"] - while not os.path.exists(self.folder_deletion_signal): - print("Existing experiment already present, waiting user input, enter 'r' or 'e'...") - time.sleep(5) - - # Once the signal file exists, read its contents - with open(self.folder_deletion_signal, "r") as signal_file: - mode = signal_file.read().strip() - - if mode == 'r' or mode == 'new': - try: - config['log_path'] = f"{config['log_path']}/client_{self.node_id}" - os.makedirs(config['log_path'], exist_ok=True) - except FileExistsError: - pass + try: + config['log_path'] = f"{config['log_path']}/node_{self.node_id}" + os.makedirs(config['log_path']) + except FileExistsError: + color_code = "\033[91m" # Red color + reset_code = "\033[0m" # Reset to default color + print(f"{color_code}Log directory for the node {self.node_id} already exists in {config['log_path']}") + print(f"Exiting to prevent accidental overwrite{reset_code}") + sys.exit(1) config['load_existing'] = False self.client_log_utils = LogUtils(config) - def local_train(self, round): + def local_train(self, round: int, **kwargs: Any): """ Train the model locally """ @@ -61,17 +45,17 @@ def local_train(self, round): self.client_log_utils.log_tb(f"train_loss/client{self.node_id}", avg_loss, round) self.client_log_utils.log_tb(f"train_accuracy/client{self.node_id}", avg_accuracy, round) - def local_test(self, **kwargs): + def local_test(self, **kwargs: Any): """ Test the model locally, not to be used in the traditional FedAvg """ pass - def get_representation(self) -> OrderedDict[str, Tensor]: + def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]: """ Share the model weights """ - return self.model.state_dict() + return self.model.state_dict() # type: ignore def set_representation(self, representation: OrderedDict[str, Tensor]): """ @@ -83,71 +67,79 @@ def run_protocol(self): start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] for round in range(start_epochs, total_epochs): - self.client_log_utils.log_summary("Client {} waiting for semaphore from {}".format(self.node_id, self.server_node)) - self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.START) - self.client_log_utils.log_summary("Client {} received semaphore from {}".format(self.node_id, self.server_node)) self.local_train(round) self.local_test() repr = self.get_representation() self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node)) - self.comm_utils.send_signal( - dest=self.server_node, data=repr, tag=self.tag.DONE - ) + self.comm_utils.send(self.server_node, repr) self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node)) - repr = self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.UPDATES - ) + repr = self.comm_utils.receive(self.server_node) self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node)) self.set_representation(repr) - self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id)) + # self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id)) class FedAvgServer(BaseServer): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) # self.set_parameters() self.config = config self.set_model_parameters(config) - self.tag = CommProtocol self.model_save_path = "{}/saved_models/node_{}.pt".format( self.config["results_path"], self.node_id ) self.folder_deletion_signal = config["folder_deletion_signal_path"] + # def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): + # # All models are sampled currently at every round + # # Each model is assumed to have equal amount of data and hence + # # coeff is same for everyone + # num_users = len(model_wts) + # coeff = 1 / num_users # this assumes each node has equal amount of data + # avgd_wts: OrderedDict[str, Tensor] = OrderedDict() + # first_model = model_wts[0] + + # for node_num in range(num_users): + # local_wts = model_wts[node_num] + # for key in first_model.keys(): + # if node_num == 0: + # avgd_wts[key] = coeff * local_wts[key].to('cpu') + # else: + # avgd_wts[key] += coeff * local_wts[key].to('cpu') + # # put the model back to the device + # for key in avgd_wts.keys(): + # avgd_wts[key] = avgd_wts[key].to(self.device) + # return avgd_wts + def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): - # All models are sampled currently at every round - # Each model is assumed to have equal amount of data and hence - # coeff is same for everyone num_users = len(model_wts) coeff = 1 / num_users - avgd_wts = OrderedDict() - first_model = model_wts[0] - - for client_num in range(num_users): - local_wts = model_wts[client_num] - for key in first_model.keys(): - if client_num == 0: - avgd_wts[key] = coeff * local_wts[key].to(self.device) - else: - avgd_wts[key] += coeff * local_wts[key].to(self.device) + avgd_wts: OrderedDict[str, Tensor] = OrderedDict() + + for key in model_wts[0].keys(): + avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore + + # Move to GPU only after averaging + for key in avgd_wts.keys(): + avgd_wts[key] = avgd_wts[key].to(self.device) return avgd_wts - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]): + def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: """ Aggregate the model weights """ avg_wts = self.fed_avg(representation_list) return avg_wts - def set_representation(self, representation): + def set_representation(self, representation: OrderedDict[str, Tensor]): """ Set the model """ - for client_node in self.users: - self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) + self.comm_utils.broadcast(representation) + print("braodcasted") self.model.load_state_dict(representation) - def test(self) -> float: + def test(self, **kwargs: Any) -> List[float]: """ Test the model on the server """ @@ -162,22 +154,16 @@ def test(self) -> float: if test_acc > self.best_acc: self.best_acc = test_acc self.model_utils.save_model(self.model, self.model_save_path) - return test_loss, test_acc, time_taken + return [test_loss, test_acc, time_taken] def single_round(self): """ Runs the whole training procedure """ - for client_node in self.users: - self.log_utils.log_console( - "Server sending semaphore from {} to {}".format( - self.node_id, client_node - ) - ) - self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.START) - self.log_utils.log_console("Server waiting for all clients to finish") - reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.DONE) - self.log_utils.log_console("Server received all clients done signal") + # calculate how much memory torch is occupying right now + # self.log_utils.log_console("Server waiting for all clients to finish") + reprs = self.comm_utils.all_gather() + # self.log_utils.log_console("Server received all clients done signal") avg_wts = self.aggregate(reprs) self.set_representation(avg_wts) #Remove the signal file after confirming that all client paths have been created @@ -185,7 +171,7 @@ def single_round(self): os.remove(self.folder_deletion_signal) def run_protocol(self): - self.log_utils.log_console("Starting iid clients federated averaging") + self.log_utils.log_console("Starting clients federated averaging") start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] for round in range(start_epochs, total_epochs): @@ -197,6 +183,6 @@ def run_protocol(self): self.log_utils.log_tb(f"test_acc/clients", acc, round) self.log_utils.log_tb(f"test_loss/clients", loss, round) self.log_utils.log_console("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) - self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) + # self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) self.log_utils.log_console("Round {} complete".format(round)) self.log_utils.log_summary("Round {} complete".format(round,)) diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 9f1044e..95ff516 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -22,7 +22,7 @@ "exp_type": "iid_clients_federated", # Learning setup "epochs": 1000, - "model": "resnet34", + "model": "resnet10", "model_lr": 3e-4, "batch_size": 256, "exp_keys": [], diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 884bd13..bfe99fd 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -1,9 +1,15 @@ # System Configuration # TODO: Set up multiple non-iid configurations here. The goal of a separate system config # is to simulate different real-world scenarios without changing the algorithm configuration. -system_config = { - "num_users": 3, - "experiment_path": "./experiments/", +from typing import Dict, List + + +mpi_system_config = { + "comm": { + "type": "MPI" + }, + "num_users": 4, + # "experiment_path": "./experiments/", "dset": "cifar10", "dump_dir": "./expt_dump/", "dpath": "./datasets/imgs/cifar10/", @@ -37,4 +43,35 @@ "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal" } -current_config = object_detect_system_config \ No newline at end of file +def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]: + """ + Get the GPU device IDs for the users. + """ + # TODO: Make it multi-host + device_ids: Dict[str, List[int]] = {} + for i in range(num_users): + index = i % len(gpus_available) + gpu_id = gpus_available[index] + device_ids[f"node_{i}"] = [gpu_id] + return device_ids + +num_users = 10 +gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7] +grpc_system_config = { + "num_users": num_users, + "comm": { + "type": "GRPC", + "peer_ids": ["localhost:50050"] # The super-node + }, + "dset": "cifar10", + "dump_dir": "./expt_dump/", + "dpath": "./datasets/imgs/cifar10/", + "seed": 2, + "device_ids": get_device_ids(num_users + 1, gpu_ids), # +1 for the super-node + "samples_per_user": 500, + "train_label_distribution": "iid", + "test_label_distribution": "iid", + "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal" +} + +current_config = grpc_system_config diff --git a/src/main.py b/src/main.py index f4952fc..4880f22 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ B_DEFAULT = "./configs/algo_config.py" S_DEFAULT = "./configs/sys_config.py" +RANK_DEFAULT = None parser = argparse.ArgumentParser(description="Run collaborative learning experiments") parser.add_argument( @@ -27,11 +28,18 @@ type=str, help=f"filepath for system config, default: {S_DEFAULT}", ) +parser.add_argument( + "-r", + nargs="?", + default=RANK_DEFAULT, + type=int, + help=f"rank of the node, default: {RANK_DEFAULT}", +) args = parser.parse_args() scheduler = Scheduler() -scheduler.assign_config_by_path(args.s, args.b) +scheduler.assign_config_by_path(args.s, args.b, args.r) print("Config loaded") scheduler.install_config() diff --git a/src/main_grpc.py b/src/main_grpc.py new file mode 100644 index 0000000..546e6c0 --- /dev/null +++ b/src/main_grpc.py @@ -0,0 +1,37 @@ +""" +This module runs collaborative learning experiments using the Scheduler class. +""" + +import argparse +import logging +import subprocess + +from utils.config_utils import load_config + +logging.getLogger("PIL").setLevel(logging.INFO) + +S_DEFAULT = "./configs/sys_config.py" +RANK_DEFAULT = 0 + +parser = argparse.ArgumentParser(description="Run collaborative learning experiments") +parser.add_argument( + "-s", + nargs="?", + default=S_DEFAULT, + type=str, + help=f"filepath for system config, default: {S_DEFAULT}", +) + +args = parser.parse_args() + +sys_config = load_config(args.s) +print("Sys config loaded") + +# 1. find the number of users in the system configuration +# 2. start separate processes by running python main.py for each user + +num_users = sys_config["num_users"] + 1 # +1 for the super-node +for i in range(num_users): + print(f"Starting process for user {i}") + # start a Popen process + subprocess.Popen(["python", "main.py", "-r", str(i)]) \ No newline at end of file diff --git a/src/resnet_in.py b/src/resnet_in.py index b160e3e..dca143b 100644 --- a/src/resnet_in.py +++ b/src/resnet_in.py @@ -5,6 +5,7 @@ This module implements ResNet models for ImageNet classification. """ +from typing import Any from torch import nn from torch.hub import load_state_dict_from_url @@ -297,7 +298,7 @@ def _resnet( return model -def resnet18(pretrained=False, progress=True, **kwargs): +def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -309,7 +310,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): -def resnet34(pretrained=False, progress=True, **kwargs): +def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: diff --git a/src/scheduler.py b/src/scheduler.py index c19e160..568a936 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -5,12 +5,10 @@ import os import random +from typing import Any, Dict -from mpi4py import MPI import torch -import numpy as np - import random import numpy @@ -31,11 +29,9 @@ from algos.fl_data_repr import FedDataRepClient, FedDataRepServer from algos.fl_val import FedValClient, FedValServer -from utils.log_utils import check_and_create_path +from utils.communication.comm_utils import CommunicationManager from utils.config_utils import load_config, process_config - from utils.log_utils import copy_source_code, check_and_create_path -from utils.config_utils import load_config, process_config, get_device_ids import os @@ -64,10 +60,10 @@ "fedval": [FedValServer, FedValClient], } -def get_node(config: dict, rank) -> BaseNode: +def get_node(config: Dict[str, Any], rank: int, comm_utils: CommunicationManager) -> BaseNode: algo_name = config["algo"] - return algo_map[algo_name][rank > 0](config) - + return algo_map[algo_name][rank > 0](config, comm_utils) + class Scheduler(): """ Manages the overall orchestration of experiments """ @@ -76,10 +72,12 @@ def __init__(self) -> None: pass def install_config(self) -> None: - self.config = process_config(self.config) + self.config: Dict[str, Any] = process_config(self.config) - def assign_config_by_path(self, sys_config_path, algo_config_path): + def assign_config_by_path(self, sys_config_path: Dict[str, Any], algo_config_path: Dict[str, Any], rank: int|None = None) -> None: self.sys_config = load_config(sys_config_path) + if rank is not None: + self.sys_config["comm"]["rank"] = rank self.algo_config = load_config(algo_config_path) self.merge_configs() @@ -88,19 +86,17 @@ def merge_configs(self): self.config.update(self.sys_config) self.config.update(self.algo_config) - def initialize(self, copy_souce_code=True) -> None: + def initialize(self, copy_souce_code: bool=True) -> None: assert self.config is not None, "Config should be set when initializing" - - comm = MPI.COMM_WORLD - rank = comm.Get_rank() + self.communication = CommunicationManager(self.config) # Base clients modify the seed later on seed = self.config["seed"] - torch.manual_seed(seed) + torch.manual_seed(seed) # type: ignore random.seed(seed) numpy.random.seed(seed) - if rank == 0: + if self.communication.get_rank() == 0: if copy_souce_code: copy_source_code(self.config) else: @@ -109,7 +105,7 @@ def initialize(self, copy_souce_code=True) -> None: os.mkdir(self.config["saved_models"]) os.mkdir(self.config["log_path"]) - self.node = get_node(self.config, rank=rank) + self.node = get_node(self.config, rank=self.communication.get_rank(), comm_utils=self.communication) def run_job(self) -> None: self.node.run_protocol() diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py new file mode 100644 index 0000000..cb4a979 --- /dev/null +++ b/src/utils/communication/comm_utils.py @@ -0,0 +1,55 @@ +from enum import Enum +from typing import Any, Dict + +from utils.communication.grpc.main import GRPCCommunication +from utils.communication.mpi import MPICommUtils + + +class CommunicationType(Enum): + MPI = 1 + GRPC = 2 + HTTP = 3 + + +class CommunicationFactory: + @staticmethod + def create_communication(config: Dict[str, Any], comm_type: CommunicationType) -> Any: + comm_type = comm_type + if comm_type == CommunicationType.MPI: + return MPICommUtils(config) + elif comm_type == CommunicationType.GRPC: + return GRPCCommunication(config) + elif comm_type == CommunicationType.HTTP: + raise NotImplementedError("HTTP communication not yet implemented") + else: + raise ValueError("Invalid communication type") + + +class CommunicationManager: + def __init__(self, config: Dict[str, Any]): + self.comm_type = CommunicationType[config["comm"]["type"]] + self.comm = CommunicationFactory.create_communication(config, self.comm_type) + self.comm.initialize() + + def get_rank(self): + if self.comm_type == CommunicationType.MPI: + return self.comm.rank + elif self.comm_type == CommunicationType.GRPC: + return self.comm.rank + else: + raise NotImplementedError("Rank not implemented for communication type", self.comm_type) + + def send(self, dest:str|int, data:Any): + self.comm.send(dest, data) + + def receive(self, node_ids: str|int) -> Any: + return self.comm.receive(node_ids) + + def broadcast(self, data: Any): + self.comm.broadcast(data) + + def all_gather(self): + return self.comm.all_gather() + + def finalize(self): + self.comm.finalize() diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto new file mode 100644 index 0000000..426b580 --- /dev/null +++ b/src/utils/communication/grpc/comm.proto @@ -0,0 +1,44 @@ +// To generate the gRPC code, run the following command: +// python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. comm.proto --pyi_out=. +syntax = "proto3"; + +service CommunicationServer { + rpc send_data (Data) returns (Empty) {} + rpc get_rank (Empty) returns (Rank) {} + rpc update_port (PeerId) returns (Empty) {} + rpc send_peer_ids (PeerIds) returns (Empty) {} + rpc send_quorum (Quorum) returns (Empty) {} +} + +message Empty {} + +message Model { + bytes buffer = 1; +} + +message Data { + string id = 1; + Model model = 2; +} + +message Rank { + int32 rank = 1; +} + +message Port { + int32 port = 1; +} + +message PeerId { + Rank rank = 1; + Port port = 2; + string ip = 3; +} + +message PeerIds { + map peer_ids = 1; +} + +message Quorum { + bool quorum = 1; +} \ No newline at end of file diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py new file mode 100644 index 0000000..5c1623f --- /dev/null +++ b/src/utils/communication/grpc/comm_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: comm.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xb9\x01\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'comm_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_PEERIDS_PEERIDSENTRY']._loaded_options = None + _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' + _globals['_EMPTY']._serialized_start=14 + _globals['_EMPTY']._serialized_end=21 + _globals['_MODEL']._serialized_start=23 + _globals['_MODEL']._serialized_end=46 + _globals['_DATA']._serialized_start=48 + _globals['_DATA']._serialized_end=89 + _globals['_RANK']._serialized_start=91 + _globals['_RANK']._serialized_end=111 + _globals['_PORT']._serialized_start=113 + _globals['_PORT']._serialized_end=133 + _globals['_PEERID']._serialized_start=135 + _globals['_PEERID']._serialized_end=197 + _globals['_PEERIDS']._serialized_start=199 + _globals['_PEERIDS']._serialized_end=306 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=251 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=306 + _globals['_QUORUM']._serialized_start=308 + _globals['_QUORUM']._serialized_end=332 + _globals['_COMMUNICATIONSERVER']._serialized_start=335 + _globals['_COMMUNICATIONSERVER']._serialized_end=520 +# @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2.pyi b/src/utils/communication/grpc/comm_pb2.pyi new file mode 100644 index 0000000..e81bc4e --- /dev/null +++ b/src/utils/communication/grpc/comm_pb2.pyi @@ -0,0 +1,65 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Empty(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Model(_message.Message): + __slots__ = ("buffer",) + BUFFER_FIELD_NUMBER: _ClassVar[int] + buffer: bytes + def __init__(self, buffer: _Optional[bytes] = ...) -> None: ... + +class Data(_message.Message): + __slots__ = ("id", "model") + ID_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + id: str + model: Model + def __init__(self, id: _Optional[str] = ..., model: _Optional[_Union[Model, _Mapping]] = ...) -> None: ... + +class Rank(_message.Message): + __slots__ = ("rank",) + RANK_FIELD_NUMBER: _ClassVar[int] + rank: int + def __init__(self, rank: _Optional[int] = ...) -> None: ... + +class Port(_message.Message): + __slots__ = ("port",) + PORT_FIELD_NUMBER: _ClassVar[int] + port: int + def __init__(self, port: _Optional[int] = ...) -> None: ... + +class PeerId(_message.Message): + __slots__ = ("rank", "port", "ip") + RANK_FIELD_NUMBER: _ClassVar[int] + PORT_FIELD_NUMBER: _ClassVar[int] + IP_FIELD_NUMBER: _ClassVar[int] + rank: Rank + port: Port + ip: str + def __init__(self, rank: _Optional[_Union[Rank, _Mapping]] = ..., port: _Optional[_Union[Port, _Mapping]] = ..., ip: _Optional[str] = ...) -> None: ... + +class PeerIds(_message.Message): + __slots__ = ("peer_ids",) + class PeerIdsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: int + value: PeerId + def __init__(self, key: _Optional[int] = ..., value: _Optional[_Union[PeerId, _Mapping]] = ...) -> None: ... + PEER_IDS_FIELD_NUMBER: _ClassVar[int] + peer_ids: _containers.MessageMap[int, PeerId] + def __init__(self, peer_ids: _Optional[_Mapping[int, PeerId]] = ...) -> None: ... + +class Quorum(_message.Message): + __slots__ = ("quorum",) + QUORUM_FIELD_NUMBER: _ClassVar[int] + quorum: bool + def __init__(self, quorum: bool = ...) -> None: ... diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py new file mode 100644 index 0000000..feb7eb8 --- /dev/null +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -0,0 +1,274 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import comm_pb2 as comm__pb2 + +GRPC_GENERATED_VERSION = '1.64.0' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in comm_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', + RuntimeWarning + ) + + +class CommunicationServerStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.send_data = channel.unary_unary( + '/CommunicationServer/send_data', + request_serializer=comm__pb2.Data.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + self.get_rank = channel.unary_unary( + '/CommunicationServer/get_rank', + request_serializer=comm__pb2.Empty.SerializeToString, + response_deserializer=comm__pb2.Rank.FromString, + _registered_method=True) + self.update_port = channel.unary_unary( + '/CommunicationServer/update_port', + request_serializer=comm__pb2.PeerId.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + self.send_peer_ids = channel.unary_unary( + '/CommunicationServer/send_peer_ids', + request_serializer=comm__pb2.PeerIds.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + self.send_quorum = channel.unary_unary( + '/CommunicationServer/send_quorum', + request_serializer=comm__pb2.Quorum.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + + +class CommunicationServerServicer(object): + """Missing associated documentation comment in .proto file.""" + + def send_data(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_rank(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def update_port(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def send_peer_ids(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def send_quorum(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_CommunicationServerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'send_data': grpc.unary_unary_rpc_method_handler( + servicer.send_data, + request_deserializer=comm__pb2.Data.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + 'get_rank': grpc.unary_unary_rpc_method_handler( + servicer.get_rank, + request_deserializer=comm__pb2.Empty.FromString, + response_serializer=comm__pb2.Rank.SerializeToString, + ), + 'update_port': grpc.unary_unary_rpc_method_handler( + servicer.update_port, + request_deserializer=comm__pb2.PeerId.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + 'send_peer_ids': grpc.unary_unary_rpc_method_handler( + servicer.send_peer_ids, + request_deserializer=comm__pb2.PeerIds.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + 'send_quorum': grpc.unary_unary_rpc_method_handler( + servicer.send_quorum, + request_deserializer=comm__pb2.Quorum.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'CommunicationServer', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('CommunicationServer', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class CommunicationServer(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def send_data(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_data', + comm__pb2.Data.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def get_rank(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/get_rank', + comm__pb2.Empty.SerializeToString, + comm__pb2.Rank.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def update_port(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/update_port', + comm__pb2.PeerId.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def send_peer_ids(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_peer_ids', + comm__pb2.PeerIds.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def send_quorum(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_quorum', + comm__pb2.Quorum.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/utils/communication/grpc/grpc_utils.py b/src/utils/communication/grpc/grpc_utils.py new file mode 100644 index 0000000..64d8e6b --- /dev/null +++ b/src/utils/communication/grpc/grpc_utils.py @@ -0,0 +1,18 @@ +from collections import OrderedDict +import io +import torch + +def serialize_model(state_dict: OrderedDict[str, torch.Tensor]) -> bytes: + # put every parameter on cpu first + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to('cpu') + buffer = io.BytesIO() + torch.save(state_dict, buffer) # type: ignore + buffer.seek(0) + return buffer.read() + +def deserialize_model(model_bytes: bytes) -> OrderedDict[str, torch.Tensor]: + buffer = io.BytesIO(model_bytes) + buffer.seek(0) + model_wts = torch.load(buffer) # type: ignore + return model_wts diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py new file mode 100644 index 0000000..82695b1 --- /dev/null +++ b/src/utils/communication/grpc/main.py @@ -0,0 +1,282 @@ +from concurrent import futures +from queue import Queue +import random +import re +import threading +import time +import socket +from typing import Any, Dict, List, OrderedDict, Union +from urllib.parse import unquote +import grpc # type: ignore +from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model +import os +import sys + +grpc_generated_dir = os.path.dirname(os.path.abspath(__file__)) +if grpc_generated_dir not in sys.path: + sys.path.append(grpc_generated_dir) + +import comm_pb2 as comm_pb2 +import comm_pb2_grpc as comm_pb2_grpc +from utils.communication.interface import CommunicationInterface + + +# TODO: Several changes needed to improve the quality of the code +# 1. We need to improve comm.proto and get rid of singletons like Rank, Port etc. +# 2. Some parts of the code are heavily nested and need to be refactored +# 3. Insert try-except blocks wherever communication is involved +# 4. Probably a good idea to move the Servicer class to a separate file +# 5. Not needed for benchmarking but for the system to be robust, we need to implement timeouts and fault tolerance +# 6. Peer_ids should be indexed by a unique identifier +# 7. Try to get rid of type: ignore as much as possible + +def is_port_available(port: int) -> bool: + """ + Check if a port is available for use. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # type: ignore + return s.connect_ex(('localhost', port)) != 0 # type: ignore + +def get_port(rank: int, num_users: int) -> int: + """ + Get the port number for the given rank. + """ + start = 50051 + while True: + port = start + rank + if is_port_available(port): + return port + # if we increment by 1 then it's likely that + # the next node will also have the same port number + start += num_users + if start > 65535: + raise Exception(f"No available ports for node {rank}") + +def parse_peer_address(peer: str) -> str: + # Remove 'ipv4:' or 'ipv6:' prefix + if peer.startswith(('ipv4:', 'ipv6:')): + peer = peer.split(':', 1)[1] + + # Handle IPv6 address + if peer.startswith('['): + # Extract IPv6 address + match = re.match(r'\[([^\]]+)\]', peer) + if match: + return unquote(match.group(1)) # Decode URL-encoded characters + else: + # Handle IPv4 address or hostname + return peer.rsplit(':', 1)[0] # Remove port number + return "" + +class Servicer(comm_pb2_grpc.CommunicationServerServicer): + def __init__(self, super_node_host: str): + self.lock = threading.Lock() + self.condition = threading.Condition(self.lock) + self.received_data: Queue[Any] = Queue() + self.quorum: Queue[bool] = Queue() + port = int(super_node_host.split(":")[1]) + ip = super_node_host.split(":")[0] + self.peer_ids: OrderedDict[int, Dict[str, int|str]] = OrderedDict({ + 0: {"rank": 0, "port": port, "ip": ip} + }) + + def send_data(self, request, context) -> comm_pb2.Empty: # type: ignore + self.received_data.put(deserialize_model(request.model.buffer)) # type: ignore + return comm_pb2.Empty() # type: ignore + + def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank: # type: ignore + try: + with self.lock: + peer = context.peer() # type: ignore + # parse the hostname from peer + peer_str = parse_peer_address(peer) # type: ignore + rank = len(self.peer_ids) + # TODO: index the peer_ids by a unique identifier + self.peer_ids[rank] = {"rank": rank, "port": 0, "ip": peer_str} + rank = self.peer_ids[rank].get("rank", -1) # Default to -1 if not found + return comm_pb2.Rank(rank=rank) # type: ignore + except Exception as e: + context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore + + def update_port(self, request: comm_pb2.PeerIds, context: grpc.ServicerContext) -> comm_pb2.Empty: + with self.lock: + self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore + return comm_pb2.Empty() # type: ignore + + def send_peer_ids(self, request, context) -> comm_pb2.Empty: # type: ignore + """ + Used by the super node to update all peers with the peer_ids + after achieving quorum. + """ + peer_ids: comm_pb2.PeerIds = request.peer_ids # type: ignore + for rank in peer_ids: # type: ignore + peer_id_proto = peer_ids[rank] # type: ignore + peer_id_dict: Dict[str, Union[int, str]] = { + "rank": peer_id_proto.rank.rank, # type: ignore + "port": peer_id_proto.port.port, # type: ignore + "ip": peer_id_proto.ip # type: ignore + } + self.peer_ids[rank] = peer_id_dict + return comm_pb2.Empty() + + def send_quorum(self, request, context) -> comm_pb2.Empty: # type: ignore + self.quorum.put(request.quorum) # type: ignore + return comm_pb2.Empty() # type: ignore + + +class GRPCCommunication(CommunicationInterface): + def __init__(self, config: Dict[str, Dict[str, Any]]): + # TODO: Implement this differently later by creating a super node + # that maintains a list of all peers + # all peers will send their IDs to this super node + # when they start. + # The implementation will have + # 1. Registration phase where every peer registers itself and gets a rank + # 2. Once a threshold number of peers have registered, the super node sets quorum to True + # 3. The super node broadcasts the peer_ids to all peers + # 4. The nodes will execute rest of the protocol in the same way as before + self.num_users: int = int(config["num_users"]) # type: ignore + self.rank: int|None = config["comm"]["rank"] + self.super_node_host: str = config["comm"]["peer_ids"][0] + if self.rank == 0: + node_id: List[str] = self.super_node_host.split(":") + self.host: str = node_id[0] + self.port: int = int(node_id[1]) + self.listener: Any = None + self.servicer = Servicer(self.super_node_host) + + @staticmethod + def get_registered_users(peer_ids: OrderedDict[int, Dict[str, int|str]]) -> int: + # count the number of entries that have a non-zero port + return len([peer_id for peer_id, values in peer_ids.items() if values.get("port") != 0]) + + def register(self): + with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) + max_tries = 10 + while max_tries > 0: + try: + self.rank = stub.get_rank(comm_pb2.Empty()).rank # type: ignore + break + except grpc.RpcError as e: + print(f"RPC failed: {e}", "Retrying...") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + self.port = get_port(self.rank, self.num_users) # type: ignore because we are setting it in the register method + rank = comm_pb2.Rank(rank=self.rank) # type: ignore + port = comm_pb2.Port(port=self.port) + peer_id = comm_pb2.PeerId(rank=rank, port=port) + stub.update_port(peer_id) # type: ignore + + def start_listener(self): + self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=4), options=[ # type: ignore + ('grpc.max_send_message_length', 100 * 1024 * 1024), # 100MB + ('grpc.max_receive_message_length', 100 * 1024 * 1024), # 100MB + ]) + comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener) # type: ignore + self.listener.add_insecure_port(f'[::]:{self.port}') + self.listener.start() + print(f'Started listener on port {self.port}') + + def peer_ids_to_proto(self, peer_ids: OrderedDict[int, Dict[str, int|str]]) -> Dict[int, comm_pb2.PeerId]: + peer_ids_proto: Dict[int, comm_pb2.PeerId] = {} + for peer_id in peer_ids: + rank = comm_pb2.Rank(rank=peer_ids[peer_id].get("rank")) # type: ignore + port = comm_pb2.Port(port=peer_ids[peer_id].get("port")) # type: ignore + ip = str(peer_ids[peer_id].get("ip")) + peer_ids_proto[peer_id] = comm_pb2.PeerId(rank=rank, port=port, ip=ip) + return peer_ids_proto + + def initialize(self): + if self.rank != 0: + self.register() + + self.start_listener() + + # wait for the quorum to be set + if self.rank != 0: + status = self.servicer.quorum.get() + if not status: + print("Quorum became false!") + sys.exit(1) + else: + quorum_threshold = self.num_users + 1 # +1 for the super node + num_registered = self.get_registered_users(self.servicer.peer_ids) + while num_registered < quorum_threshold: + # sleep for 5 seconds + print(f"Waiting for {quorum_threshold} users to register, {num_registered} have registered so far") + time.sleep(5) + num_registered = self.get_registered_users(self.servicer.peer_ids) + # TODO: Implement a timeout here and if the timeout is reached + # then set quorum to False for all registered users + # and exit the program + + print("All users have registered", self.servicer.peer_ids) + for peer_id in self.servicer.peer_ids: + host_ip = self.servicer.peer_ids[peer_id].get("ip") + own_ip = self.servicer.peer_ids[0].get("ip") + if host_ip != own_ip: + port = self.servicer.peer_ids[peer_id].get("port") + address = f"{host_ip}:{port}" + with grpc.insecure_channel(address) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) + proto_msg = comm_pb2.PeerIds(peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids)) + stub.send_peer_ids(proto_msg) # type: ignore + stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + + def get_host_from_rank(self, rank: int) -> str: + for peer_id in self.servicer.peer_ids: + if self.servicer.peer_ids[peer_id].get("rank") == rank: + return self.servicer.peer_ids[peer_id].get("ip") + ":" + str(self.servicer.peer_ids[peer_id].get("port")) # type: ignore + raise Exception(f"Rank {rank} not found in peer_ids") + + def send(self, dest: str|int, data: OrderedDict[str, Any]): + """ + data should be a torch model + """ + dest_host: str = "" + if type(dest) == int: + dest_host = self.get_host_from_rank(dest) + else: + dest_host = str(dest) + try: + buffer = serialize_model(data) + with grpc.insecure_channel(dest_host) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore + model = comm_pb2.Model(buffer=buffer) # type: ignore + stub.send_data(comm_pb2.Data(model=model, id=str(self.rank))) # type: ignore + except grpc.RpcError as e: + print(f"RPC failed: {e}") + sys.exit(1) + + def receive(self, node_ids: str|int) -> Any: + # this .get() will block until + # at least 1 item is received in the queue + return self.servicer.received_data.get() + + def is_own_id(self, peer_id: int) -> bool: + rank = self.servicer.peer_ids[peer_id].get("rank") + if rank != self.rank: + return False + return True + + def broadcast(self, data: Any): + for peer_id in self.servicer.peer_ids: + if not self.is_own_id(peer_id): + self.send(peer_id, data) + + def all_gather(self) -> Any: + # this will block until all items are received + # from all peers + items: List[Any] = [] + for peer_id in self.servicer.peer_ids: + if not self.is_own_id(peer_id): + items.append(self.receive(peer_id)) + return items + + def finalize(self): + if self.listener: + self.listener.stop(0) + print(f'Stopped server on port {self.port}') diff --git a/src/utils/communication/interface.py b/src/utils/communication/interface.py new file mode 100644 index 0000000..b5ab2ba --- /dev/null +++ b/src/utils/communication/interface.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import Any + +class CommunicationInterface(ABC): + def __init__(self): + pass + + @abstractmethod + def initialize(self): + pass + + @abstractmethod + def send(self, dest: str|int, data: Any): + pass + + @abstractmethod + def receive(self, node_ids: str|int) -> Any: + pass + + @abstractmethod + def broadcast(self, data: Any): + pass + + @abstractmethod + def all_gather(self) -> Any: + pass + + @abstractmethod + def finalize(self): + pass diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py new file mode 100644 index 0000000..ea5bad8 --- /dev/null +++ b/src/utils/communication/mpi.py @@ -0,0 +1,35 @@ +from typing import Dict, Any, List +from mpi4py import MPI +from utils.communication.interface import CommunicationInterface + +class MPICommUtils(CommunicationInterface): + def __init__(self, config: Dict[str, Dict[str, Any]]): + self.comm = MPI.COMM_WORLD + self.rank = self.comm.Get_rank() + self.size = self.comm.Get_size() + + def initialize(self): + pass + + def send(self, dest: str|int, data: Any): + self.comm.send(data, dest=int(dest)) + + def receive(self, node_ids: str|int) -> Any: + return self.comm.recv(source=int(node_ids)) + + def broadcast(self, data: Any): + for i in range(1, self.size): + if i != self.rank: + self.send(i, data) + + def all_gather(self): + """ + This function is used to gather data from all the nodes. + """ + items: List[Any] = [] + for i in range(1, self.size): + items.append(self.receive(i)) + return items + + def finalize(self): + pass diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py index b3cbc8a..51407d9 100644 --- a/src/utils/config_utils.py +++ b/src/utils/config_utils.py @@ -1,8 +1,9 @@ +from typing import Any, Dict import jmespath import importlib -def load_config(config_path): +def load_config(config_path: str) -> Dict[str, Any]: path = ".".join(config_path.split(".")[1].split("/")[1:]) config = importlib.import_module(path).current_config return config diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py index d3ea63b..4fbe2ef 100644 --- a/src/utils/data_utils.py +++ b/src/utils/data_utils.py @@ -1,4 +1,5 @@ import importlib +from typing import Tuple import numpy as np import torch import torchvision.transforms as T @@ -93,7 +94,7 @@ def filter_by_class(dataset, classes): return Subset(dataset, indices), indices -def random_samples(dataset, num_samples): +def random_samples(dataset, num_samples) -> Tuple[Subset, np.ndarray]: """ Returns a random subset of samples from the dataset. """ @@ -110,7 +111,7 @@ def extr_noniid(train_dataset, samples_per_client, classes): return Subset(all_data, perm[:samples_per_client]) -def cifar_extr_noniid( +def cifar_extr_noniid( train_dataset, test_dataset, num_users, n_class, num_samples, rate_unbalance ): """ diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py index 3dd6612..96afea2 100644 --- a/src/utils/log_utils.py +++ b/src/utils/log_utils.py @@ -9,11 +9,11 @@ import sys from glob import glob from shutil import copytree, copy2 -from PIL import Image +from typing import Any, Dict import torch -import torchvision.transforms as T -from torchvision.utils import make_grid, save_image -from tensorboardX import SummaryWriter +import torchvision.transforms as T # type: ignore +from torchvision.utils import make_grid, save_image # type: ignore +from tensorboardX import SummaryWriter # type: ignore import numpy as np @@ -36,7 +36,7 @@ def deprocess(img): return img.type(torch.uint8) -def check_and_create_path(path, folder_deletion_path): +def check_and_create_path(path: str, folder_deletion_path: str|None=None): """ Checks if the specified path exists and prompts the user for action if it does. Creates the directory if it does not exist. @@ -45,26 +45,11 @@ def check_and_create_path(path, folder_deletion_path): path (str): Path to check and create if necessary. """ if os.path.isdir(path): - print(f"Experiment in {path} already present") - done = False - while not done: - # Color the input prompt - color_code = "\033[94m" # Blue text - reset_code = "\033[0m" # Reset to default color - # Highlighted prompt in blue - inp = input(f"{color_code}Press e to exit, r to replace it: {reset_code}") - - if inp == "e": - sys.exit() - elif inp == "r": - done = True - shutil.rmtree(path) - os.makedirs(path) - with open(folder_deletion_path, "w") as signal_file: - #Folder deletion complete signal. - signal_file.write("r") - else: - print("Input not understood") + color_code = "\033[94m" # Blue text + reset_code = "\033[0m" # Reset to default color + print(f"{color_code}Experiment in {path} already present. Exiting.") + print(f"Please do: rm -rf {path} to delete the folder.{reset_code}") + sys.exit() else: os.makedirs(path) with open(folder_deletion_path, "w") as signal_file: @@ -72,7 +57,7 @@ def check_and_create_path(path, folder_deletion_path): signal_file.write("new") -def copy_source_code(config: dict) -> None: +def copy_source_code(config: Dict[str, Any]) -> None: """ Copy source code to experiment folder for reproducibility. @@ -120,7 +105,7 @@ class LogUtils: """ Utility class for logging and saving experiment data. """ - def __init__(self, config) -> None: + def __init__(self, config: Dict[str, Any]) -> None: log_dir = config["log_path"] load_existing = config["load_existing"] log_format = ( @@ -144,7 +129,7 @@ def init_summary(self): """ self.summary_file = open(f"{self.log_dir}/summary.txt", "w", encoding="utf-8") - def init_tb(self, load_existing): + def init_tb(self, load_existing: bool): """ Initialize TensorBoard logging. @@ -164,7 +149,7 @@ def init_npy(self): if not os.path.exists(npy_path) or not os.path.isdir(npy_path): os.makedirs(npy_path) - def log_summary(self, text): + def log_summary(self, text: str): """ Add summary text to the summary file for logging. """ @@ -174,7 +159,7 @@ def log_summary(self, text): else: raise ValueError("Summary file is not initialized. Call init_summary() first.") - def log_image(self, imgs: torch.Tensor, key, iteration): + def log_image(self, imgs: torch.Tensor, key: str, iteration: int): """ Log image to file and TensorBoard. @@ -187,7 +172,7 @@ def log_image(self, imgs: torch.Tensor, key, iteration): save_image(grid_img, f"{self.log_dir}/{iteration}_{key}.png") self.writer.add_image(key, grid_img.numpy(), iteration) - def log_console(self, msg): + def log_console(self, msg: str): """ Log a message to the console. @@ -196,7 +181,7 @@ def log_console(self, msg): """ logging.info(msg) - def log_tb(self, key, value, iteration): + def log_tb(self, key: str, value: float|int, iteration: int): """ Log a scalar value to TensorBoard. @@ -205,9 +190,9 @@ def log_tb(self, key, value, iteration): value (float): Value to log. iteration (int): Current iteration number. """ - self.writer.add_scalar(key, value, iteration) + self.writer.add_scalar(key, value, iteration) # type: ignore - def log_npy(self, key, value): + def log_npy(self, key: str, value: np.ndarray): """ Save a numpy array to file. @@ -217,7 +202,7 @@ def log_npy(self, key, value): """ np.save(f"{self.log_dir}/npy/{key}.npy", value) - def log_max_stats_per_client(self, stats_per_client, round_step, metric): + def log_max_stats_per_client(self, stats_per_client: np.ndarray, round_step: int, metric: str): """ Log maximum statistics per client. diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index 8441872..14ea93b 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -1,9 +1,10 @@ from collections import OrderedDict -from typing import Tuple, List +from typing import Any, Tuple, List import torch from torch import Tensor import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader from torch.nn.parallel import DataParallel import resnet @@ -13,7 +14,8 @@ class ModelUtils(): - def __init__(self) -> None: + def __init__(self, device: torch.device) -> None: + self.device = device self.models_layers_idx = { "resnet10": { @@ -33,14 +35,14 @@ def __init__(self) -> None: } def get_model( - self, - model_name: str, - dset: str, - device: torch.device, - device_ids: list, - pretrained=False, - **kwargs) -> DataParallel: - self.dset = dset # cifar10, wilds, pascal + self, + model_name: str, + dset: str, + device: torch.device, + device_ids: List[int], + pretrained:bool=False, + **kwargs: Any + ) -> nn.Module: # TODO: add support for loading checkpointed models model_name = model_name.lower() if model_name == "resnet10": @@ -65,12 +67,12 @@ def get_model( def train(self, model: nn.Module, - optim, - dloader, - loss_fn, + optim: torch.optim.Optimizer, + dloader: DataLoader[Any], + loss_fn: Any, device: torch.device, - test_loader=None, - **kwargs) -> Tuple[float, + test_loader: DataLoader[Any]|None=None, + **kwargs: Any) -> Tuple[float, float]: """TODO: generate docstring """ @@ -291,8 +293,8 @@ def deep_mutual_train(self, models, optim, dloader, train_accuracy[i] = 100. * correct[i] / len(dloader.dataset) return train_loss, train_accuracy - def test(self, model, dloader, loss_fn, device, - **kwargs) -> Tuple[float, float]: + def test(self, model: nn.Module, dloader: DataLoader[Any], loss_fn: Any, device: torch.device, + **kwargs: Any) -> Tuple[float, float]: """TODO: generate docstring """ model.eval() @@ -362,7 +364,7 @@ def test_classification(self, model, dloader, loss_fn, device, **kwargs) -> Tupl acc = correct / len(dloader.dataset) return test_loss, acc - def save_model(self, model, path): + def save_model(self, model: nn.Module, path: str) -> None: if isinstance(model, DataParallel): model_ = model.module else: @@ -403,3 +405,9 @@ def filter_model_weights( if key not in key_to_ignore: filtered_model_wts[key] = param return filtered_model_wts + + def get_memory_usage(self): + """ + Get the memory usage + """ + return torch.cuda.memory_allocated(self.device)