diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 6c4eeeb..ff8a53f 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, Subset from collections import OrderedDict -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from torch import Tensor import copy import random @@ -28,11 +28,11 @@ get_dset_balanced_communities, get_dset_communities, ) -import torchvision.transforms as T +import torchvision.transforms as T # type: ignore import os class BaseNode(ABC): - def __init__(self, config, comm_utils: CommunicationManager) -> None: + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: self.comm_utils = comm_utils self.node_id = self.comm_utils.get_rank() @@ -52,13 +52,13 @@ def __init__(self, config, comm_utils: CommunicationManager) -> None: if isinstance(config["dset"], dict): if self.node_id != 0: config["dset"].pop("0") - self.dset = config["dset"][str(self.node_id)] + self.dset = str(config["dset"][str(self.node_id)]) config["dpath"] = config["dpath"][self.dset] else: self.dset = config["dset"] self.setup_cuda(config) - self.model_utils = ModelUtils() + self.model_utils = ModelUtils(self.device) self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) self.set_constants() @@ -66,7 +66,7 @@ def __init__(self, config, comm_utils: CommunicationManager) -> None: def set_constants(self): self.best_acc = 0.0 - def setup_cuda(self, config): + def setup_cuda(self, config: Dict[str, Any]): # Need a mapping from rank to device id device_ids_map = config["device_ids"] node_name = "node_{}".format(self.node_id) @@ -80,7 +80,7 @@ def setup_cuda(self, config): self.device = torch.device("cpu") print("Using CPU") - def set_model_parameters(self, config): + def set_model_parameters(self, config: Dict[str, Any]): # Model related parameters optim_name = config.get("optimizer", "adam") if optim_name == "adam": @@ -144,7 +144,7 @@ def set_shared_exp_parameters(self, config): self.log_utils.log_console("Communities: {}".format(self.communities)) @abstractmethod - def run_protocol(self): + def run_protocol(self) -> None: raise NotImplementedError @@ -368,13 +368,13 @@ def local_train(self, round: int, **kwargs: Any) -> None: """ raise NotImplementedError - def local_test(self, dataset, **kwargs): + def local_test(self, **kwargs: Any) -> float | Tuple[float, float] | None: """ Test the model locally """ raise NotImplementedError - def get_representation(self, **kwargs): + def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor: """ Share the model representation """ @@ -429,13 +429,13 @@ def set_data_parameters(self, config): batch_size = config["batch_size"] self._test_loader = DataLoader(test_dset, batch_size=batch_size) - def aggregate(self, representation_list, **kwargs): + def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: """ Aggregate the knowledge from the users """ raise NotImplementedError - def test(self, dataset, **kwargs): + def test(self, **kwargs: Any) -> List[float]: """ Test the model on the server """ diff --git a/src/algos/fl.py b/src/algos/fl.py index 036f632..94eb368 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -45,17 +45,17 @@ def local_train(self, round: int, **kwargs: Any): self.client_log_utils.log_tb(f"train_loss/client{self.node_id}", avg_loss, round) self.client_log_utils.log_tb(f"train_accuracy/client{self.node_id}", avg_accuracy, round) - def local_test(self, **kwargs): + def local_test(self, **kwargs: Any): """ Test the model locally, not to be used in the traditional FedAvg """ pass - def get_representation(self) -> OrderedDict[str, Tensor]: + def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]: """ Share the model weights """ - return self.model.state_dict() + return self.model.state_dict() # type: ignore def set_representation(self, representation: OrderedDict[str, Tensor]): """ @@ -70,11 +70,11 @@ def run_protocol(self): self.local_train(round) self.local_test() repr = self.get_representation() - # self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node)) + self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node)) self.comm_utils.send(self.server_node, repr) - # self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node)) + self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node)) repr = self.comm_utils.receive(self.server_node) - # self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node)) + self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node)) self.set_representation(repr) # self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id)) @@ -90,32 +90,48 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> ) self.folder_deletion_signal = config["folder_deletion_signal_path"] + # def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): + # # All models are sampled currently at every round + # # Each model is assumed to have equal amount of data and hence + # # coeff is same for everyone + # num_users = len(model_wts) + # coeff = 1 / num_users # this assumes each node has equal amount of data + # avgd_wts: OrderedDict[str, Tensor] = OrderedDict() + # first_model = model_wts[0] + + # for node_num in range(num_users): + # local_wts = model_wts[node_num] + # for key in first_model.keys(): + # if node_num == 0: + # avgd_wts[key] = coeff * local_wts[key].to('cpu') + # else: + # avgd_wts[key] += coeff * local_wts[key].to('cpu') + # # put the model back to the device + # for key in avgd_wts.keys(): + # avgd_wts[key] = avgd_wts[key].to(self.device) + # return avgd_wts + def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): - # All models are sampled currently at every round - # Each model is assumed to have equal amount of data and hence - # coeff is same for everyone num_users = len(model_wts) coeff = 1 / num_users - avgd_wts = OrderedDict() - first_model = model_wts[0] - - for node_num in range(num_users): - local_wts = model_wts[node_num] - for key in first_model.keys(): - if node_num == 0: - avgd_wts[key] = coeff * local_wts[key].to(self.device) - else: - avgd_wts[key] += coeff * local_wts[key].to(self.device) + avgd_wts: OrderedDict[str, Tensor] = OrderedDict() + + for key in model_wts[0].keys(): + avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore + + # Move to GPU only after averaging + for key in avgd_wts.keys(): + avgd_wts[key] = avgd_wts[key].to(self.device) return avgd_wts - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]): + def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: """ Aggregate the model weights """ avg_wts = self.fed_avg(representation_list) return avg_wts - def set_representation(self, representation): + def set_representation(self, representation: OrderedDict[str, Tensor]): """ Set the model """ @@ -123,7 +139,7 @@ def set_representation(self, representation): print("braodcasted") self.model.load_state_dict(representation) - def test(self) -> float: + def test(self, **kwargs: Any) -> List[float]: """ Test the model on the server """ @@ -138,12 +154,13 @@ def test(self) -> float: if test_acc > self.best_acc: self.best_acc = test_acc self.model_utils.save_model(self.model, self.model_save_path) - return test_loss, test_acc, time_taken + return [test_loss, test_acc, time_taken] def single_round(self): """ Runs the whole training procedure """ + # calculate how much memory torch is occupying right now # self.log_utils.log_console("Server waiting for all clients to finish") reprs = self.comm_utils.all_gather() # self.log_utils.log_console("Server received all clients done signal") @@ -166,6 +183,6 @@ def run_protocol(self): self.log_utils.log_tb(f"test_acc/clients", acc, round) self.log_utils.log_tb(f"test_loss/clients", loss, round) self.log_utils.log_console("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) - self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) + # self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) self.log_utils.log_console("Round {} complete".format(round)) self.log_utils.log_summary("Round {} complete".format(round,)) diff --git a/src/resnet_in.py b/src/resnet_in.py index b160e3e..dca143b 100644 --- a/src/resnet_in.py +++ b/src/resnet_in.py @@ -5,6 +5,7 @@ This module implements ResNet models for ImageNet classification. """ +from typing import Any from torch import nn from torch.hub import load_state_dict_from_url @@ -297,7 +298,7 @@ def _resnet( return model -def resnet18(pretrained=False, progress=True, **kwargs): +def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -309,7 +310,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): -def resnet34(pretrained=False, progress=True, **kwargs): +def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: diff --git a/src/utils/communication/grpc/grpc_utils.py b/src/utils/communication/grpc/grpc_utils.py index a97074b..64d8e6b 100644 --- a/src/utils/communication/grpc/grpc_utils.py +++ b/src/utils/communication/grpc/grpc_utils.py @@ -3,6 +3,9 @@ import torch def serialize_model(state_dict: OrderedDict[str, torch.Tensor]) -> bytes: + # put every parameter on cpu first + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to('cpu') buffer = io.BytesIO() torch.save(state_dict, buffer) # type: ignore buffer.seek(0) diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index 72efb43..82695b1 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -153,7 +153,7 @@ def get_registered_users(peer_ids: OrderedDict[int, Dict[str, int|str]]) -> int: def register(self): with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore stub = comm_pb2_grpc.CommunicationServerStub(channel) - max_tries = 5 + max_tries = 10 while max_tries > 0: try: self.rank = stub.get_rank(comm_pb2.Empty()).rank # type: ignore @@ -203,10 +203,12 @@ def initialize(self): sys.exit(1) else: quorum_threshold = self.num_users + 1 # +1 for the super node - while self.get_registered_users(self.servicer.peer_ids) < quorum_threshold: + num_registered = self.get_registered_users(self.servicer.peer_ids) + while num_registered < quorum_threshold: # sleep for 5 seconds - print(f"Waiting for {quorum_threshold} users to register") + print(f"Waiting for {quorum_threshold} users to register, {num_registered} have registered so far") time.sleep(5) + num_registered = self.get_registered_users(self.servicer.peer_ids) # TODO: Implement a timeout here and if the timeout is reached # then set quorum to False for all registered users # and exit the program diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py index 11efa53..fa6e69c 100644 --- a/src/utils/data_utils.py +++ b/src/utils/data_utils.py @@ -1,4 +1,5 @@ import importlib +from typing import Tuple import numpy as np import torch import torchvision.transforms as T @@ -92,7 +93,7 @@ def filter_by_class(dataset, classes): return Subset(dataset, indices), indices -def random_samples(dataset, num_samples): +def random_samples(dataset, num_samples) -> Tuple[Subset, np.ndarray]: """ Returns a random subset of samples from the dataset. """ diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py index 6392e8f..86fcdfb 100644 --- a/src/utils/log_utils.py +++ b/src/utils/log_utils.py @@ -9,11 +9,11 @@ import sys from glob import glob from shutil import copytree, copy2 -from PIL import Image +from typing import Any, Dict import torch -import torchvision.transforms as T -from torchvision.utils import make_grid, save_image -from tensorboardX import SummaryWriter +import torchvision.transforms as T # type: ignore +from torchvision.utils import make_grid, save_image # type: ignore +from tensorboardX import SummaryWriter # type: ignore import numpy as np @@ -57,7 +57,7 @@ def check_and_create_path(path: str, folder_deletion_path: str|None=None): signal_file.write("new") -def copy_source_code(config: dict) -> None: +def copy_source_code(config: Dict[str, Any]) -> None: """ Copy source code to experiment folder for reproducibility. @@ -104,7 +104,7 @@ class LogUtils: """ Utility class for logging and saving experiment data. """ - def __init__(self, config) -> None: + def __init__(self, config: Dict[str, Any]) -> None: log_dir = config["log_path"] load_existing = config["load_existing"] log_format = ( @@ -128,7 +128,7 @@ def init_summary(self): """ self.summary_file = open(f"{self.log_dir}/summary.txt", "w", encoding="utf-8") - def init_tb(self, load_existing): + def init_tb(self, load_existing: bool): """ Initialize TensorBoard logging. @@ -148,7 +148,7 @@ def init_npy(self): if not os.path.exists(npy_path) or not os.path.isdir(npy_path): os.makedirs(npy_path) - def log_summary(self, text): + def log_summary(self, text: str): """ Add summary text to the summary file for logging. """ @@ -158,7 +158,7 @@ def log_summary(self, text): else: raise ValueError("Summary file is not initialized. Call init_summary() first.") - def log_image(self, imgs: torch.Tensor, key, iteration): + def log_image(self, imgs: torch.Tensor, key: str, iteration: int): """ Log image to file and TensorBoard. @@ -171,7 +171,7 @@ def log_image(self, imgs: torch.Tensor, key, iteration): save_image(grid_img, f"{self.log_dir}/{iteration}_{key}.png") self.writer.add_image(key, grid_img.numpy(), iteration) - def log_console(self, msg): + def log_console(self, msg: str): """ Log a message to the console. @@ -180,7 +180,7 @@ def log_console(self, msg): """ logging.info(msg) - def log_tb(self, key, value, iteration): + def log_tb(self, key: str, value: float|int, iteration: int): """ Log a scalar value to TensorBoard. @@ -189,9 +189,9 @@ def log_tb(self, key, value, iteration): value (float): Value to log. iteration (int): Current iteration number. """ - self.writer.add_scalar(key, value, iteration) + self.writer.add_scalar(key, value, iteration) # type: ignore - def log_npy(self, key, value): + def log_npy(self, key: str, value: np.ndarray): """ Save a numpy array to file. @@ -201,7 +201,7 @@ def log_npy(self, key, value): """ np.save(f"{self.log_dir}/npy/{key}.npy", value) - def log_max_stats_per_client(self, stats_per_client, round_step, metric): + def log_max_stats_per_client(self, stats_per_client: np.ndarray, round_step: int, metric: str): """ Log maximum statistics per client. diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index d2fd261..406cd2a 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -1,9 +1,10 @@ from collections import OrderedDict -from typing import Tuple, List +from typing import Any, Tuple, List import torch from torch import Tensor import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader from torch.nn.parallel import DataParallel import resnet @@ -11,8 +12,8 @@ class ModelUtils(): - def __init__(self) -> None: - pass + def __init__(self, device: torch.device) -> None: + self.device = device self.models_layers_idx = { "resnet10": { @@ -36,9 +37,9 @@ def get_model( model_name: str, dset: str, device: torch.device, - device_ids: list, - pretrained=False, - **kwargs) -> DataParallel: + device_ids: List[int], + pretrained:bool=False, + **kwargs: Any) -> nn.Module: # TODO: add support for loading checkpointed models model_name = model_name.lower() if model_name == "resnet10": @@ -61,12 +62,12 @@ def get_model( def train(self, model: nn.Module, - optim, - dloader, - loss_fn, + optim: torch.optim.Optimizer, + dloader: DataLoader[Any], + loss_fn: Any, device: torch.device, - test_loader=None, - **kwargs) -> Tuple[float, + test_loader: DataLoader[Any]|None=None, + **kwargs: Any) -> Tuple[float, float]: """TODO: generate docstring """ @@ -185,8 +186,8 @@ def deep_mutual_train(self, models, optim, dloader, train_accuracy[i] = 100. * correct[i] / len(dloader.dataset) return train_loss, train_accuracy - def test(self, model, dloader, loss_fn, device, - **kwargs) -> Tuple[float, float]: + def test(self, model: nn.Module, dloader: DataLoader[Any], loss_fn: Any, device: torch.device, + **kwargs: Any) -> Tuple[float, float]: """TODO: generate docstring """ model.eval() @@ -209,7 +210,7 @@ def test(self, model, dloader, loss_fn, device, acc = correct / len(dloader.dataset) return test_loss, acc - def save_model(self, model, path): + def save_model(self, model: nn.Module, path: str) -> None: if isinstance(model, DataParallel): model_ = model.module else: @@ -250,3 +251,9 @@ def filter_model_weights( if key not in key_to_ignore: filtered_model_wts[key] = param return filtered_model_wts + + def get_memory_usage(self): + """ + Get the memory usage + """ + return torch.cuda.memory_allocated(self.device)