Skip to content

Commit

Permalink
reduce memory footprint and improve type checking
Browse files (browse the repository at this point in the history)
  • Loading branch information
tremblerz committed Aug 15, 2024
1 parent 1463894 commit 43d016e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 70 deletions.
24 changes: 12 additions & 12 deletions src/algos/base_class.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.data import DataLoader, Subset

from collections import OrderedDict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from torch import Tensor
import copy
import random
Expand All @@ -28,11 +28,11 @@
get_dset_balanced_communities,
get_dset_communities,
)
import torchvision.transforms as T
import torchvision.transforms as T # type: ignore
import os

class BaseNode(ABC):
def __init__(self, config, comm_utils: CommunicationManager) -> None:
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
self.comm_utils = comm_utils
self.node_id = self.comm_utils.get_rank()

Expand All @@ -52,21 +52,21 @@ def __init__(self, config, comm_utils: CommunicationManager) -> None:
if isinstance(config["dset"], dict):
if self.node_id != 0:
config["dset"].pop("0")
self.dset = config["dset"][str(self.node_id)]
self.dset = str(config["dset"][str(self.node_id)])
config["dpath"] = config["dpath"][self.dset]
else:
self.dset = config["dset"]

self.setup_cuda(config)
self.model_utils = ModelUtils()
self.model_utils = ModelUtils(self.device)

self.dset_obj = get_dataset(self.dset, dpath=config["dpath"])
self.set_constants()

def set_constants(self):
self.best_acc = 0.0

def setup_cuda(self, config):
def setup_cuda(self, config: Dict[str, Any]):
# Need a mapping from rank to device id
device_ids_map = config["device_ids"]
node_name = "node_{}".format(self.node_id)
Expand All @@ -80,7 +80,7 @@ def setup_cuda(self, config):
self.device = torch.device("cpu")
print("Using CPU")

def set_model_parameters(self, config):
def set_model_parameters(self, config: Dict[str, Any]):
# Model related parameters
optim_name = config.get("optimizer", "adam")
if optim_name == "adam":
Expand Down Expand Up @@ -144,7 +144,7 @@ def set_shared_exp_parameters(self, config):
self.log_utils.log_console("Communities: {}".format(self.communities))

@abstractmethod
def run_protocol(self):
def run_protocol(self) -> None:
raise NotImplementedError


Expand Down Expand Up @@ -368,13 +368,13 @@ def local_train(self, round: int, **kwargs: Any) -> None:
"""
raise NotImplementedError

def local_test(self, dataset, **kwargs):
def local_test(self, **kwargs: Any) -> float | Tuple[float, float] | None:
"""
Test the model locally
"""
raise NotImplementedError

def get_representation(self, **kwargs):
def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor:
"""
Share the model representation
"""
Expand Down Expand Up @@ -429,13 +429,13 @@ def set_data_parameters(self, config):
batch_size = config["batch_size"]
self._test_loader = DataLoader(test_dset, batch_size=batch_size)

def aggregate(self, representation_list, **kwargs):
def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Aggregate the knowledge from the users
"""
raise NotImplementedError

def test(self, dataset, **kwargs):
def test(self, **kwargs: Any) -> List[float]:
"""
Test the model on the server
"""
Expand Down
65 changes: 41 additions & 24 deletions src/algos/fl.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,17 +45,17 @@ def local_train(self, round: int, **kwargs: Any):
self.client_log_utils.log_tb(f"train_loss/client{self.node_id}", avg_loss, round)
self.client_log_utils.log_tb(f"train_accuracy/client{self.node_id}", avg_accuracy, round)

def local_test(self, **kwargs):
def local_test(self, **kwargs: Any):
"""
Test the model locally, not to be used in the traditional FedAvg
"""
pass

def get_representation(self) -> OrderedDict[str, Tensor]:
def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Share the model weights
"""
return self.model.state_dict()
return self.model.state_dict() # type: ignore

def set_representation(self, representation: OrderedDict[str, Tensor]):
"""
Expand All @@ -70,11 +70,11 @@ def run_protocol(self):
self.local_train(round)
self.local_test()
repr = self.get_representation()
# self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node))
self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node))
self.comm_utils.send(self.server_node, repr)
# self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node))
self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node))
repr = self.comm_utils.receive(self.server_node)
# self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node))
self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node))
self.set_representation(repr)
# self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id))

Expand All @@ -90,40 +90,56 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) ->
)
self.folder_deletion_signal = config["folder_deletion_signal_path"]

# def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
# # All models are sampled currently at every round
# # Each model is assumed to have equal amount of data and hence
# # coeff is same for everyone
# num_users = len(model_wts)
# coeff = 1 / num_users # this assumes each node has equal amount of data
# avgd_wts: OrderedDict[str, Tensor] = OrderedDict()
# first_model = model_wts[0]

# for node_num in range(num_users):
# local_wts = model_wts[node_num]
# for key in first_model.keys():
# if node_num == 0:
# avgd_wts[key] = coeff * local_wts[key].to('cpu')
# else:
# avgd_wts[key] += coeff * local_wts[key].to('cpu')
# # put the model back to the device
# for key in avgd_wts.keys():
# avgd_wts[key] = avgd_wts[key].to(self.device)
# return avgd_wts

def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
# All models are sampled currently at every round
# Each model is assumed to have equal amount of data and hence
# coeff is same for everyone
num_users = len(model_wts)
coeff = 1 / num_users
avgd_wts = OrderedDict()
first_model = model_wts[0]

for node_num in range(num_users):
local_wts = model_wts[node_num]
for key in first_model.keys():
if node_num == 0:
avgd_wts[key] = coeff * local_wts[key].to(self.device)
else:
avgd_wts[key] += coeff * local_wts[key].to(self.device)
avgd_wts: OrderedDict[str, Tensor] = OrderedDict()

for key in model_wts[0].keys():
avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore

# Move to GPU only after averaging
for key in avgd_wts.keys():
avgd_wts[key] = avgd_wts[key].to(self.device)
return avgd_wts

def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]):
def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Aggregate the model weights
"""
avg_wts = self.fed_avg(representation_list)
return avg_wts

def set_representation(self, representation):
def set_representation(self, representation: OrderedDict[str, Tensor]):
"""
Set the model
"""
self.comm_utils.broadcast(representation)
print("braodcasted")
self.model.load_state_dict(representation)

def test(self) -> float:
def test(self, **kwargs: Any) -> List[float]:
"""
Test the model on the server
"""
Expand All @@ -138,12 +154,13 @@ def test(self) -> float:
if test_acc > self.best_acc:
self.best_acc = test_acc
self.model_utils.save_model(self.model, self.model_save_path)
return test_loss, test_acc, time_taken
return [test_loss, test_acc, time_taken]

def single_round(self):
"""
Runs the whole training procedure
"""
# calculate how much memory torch is occupying right now
# self.log_utils.log_console("Server waiting for all clients to finish")
reprs = self.comm_utils.all_gather()
# self.log_utils.log_console("Server received all clients done signal")
Expand All @@ -166,6 +183,6 @@ def run_protocol(self):
self.log_utils.log_tb(f"test_acc/clients", acc, round)
self.log_utils.log_tb(f"test_loss/clients", loss, round)
self.log_utils.log_console("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
# self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
self.log_utils.log_console("Round {} complete".format(round))
self.log_utils.log_summary("Round {} complete".format(round,))
5 changes: 3 additions & 2 deletions src/resnet_in.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@
This module implements ResNet models for ImageNet classification.
"""

from typing import Any
from torch import nn
from torch.hub import load_state_dict_from_url

Expand Down Expand Up @@ -297,7 +298,7 @@ def _resnet(
return model


def resnet18(pretrained=False, progress=True, **kwargs):
def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -309,7 +310,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):



def resnet34(pretrained=False, progress=True, **kwargs):
def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand Down
3 changes: 3 additions & 0 deletions src/utils/communication/grpc/grpc_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,9 @@
import torch

def serialize_model(state_dict: OrderedDict[str, torch.Tensor]) -> bytes:
# put every parameter on cpu first
for key in state_dict.keys():
state_dict[key] = state_dict[key].to('cpu')
buffer = io.BytesIO()
torch.save(state_dict, buffer) # type: ignore
buffer.seek(0)
Expand Down
8 changes: 5 additions & 3 deletions src/utils/communication/grpc/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_registered_users(peer_ids: OrderedDict[int, Dict[str, int|str]]) -> int:
def register(self):
with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore
stub = comm_pb2_grpc.CommunicationServerStub(channel)
max_tries = 5
max_tries = 10
while max_tries > 0:
try:
self.rank = stub.get_rank(comm_pb2.Empty()).rank # type: ignore
Expand Down Expand Up @@ -203,10 +203,12 @@ def initialize(self):
sys.exit(1)
else:
quorum_threshold = self.num_users + 1 # +1 for the super node
while self.get_registered_users(self.servicer.peer_ids) < quorum_threshold:
num_registered = self.get_registered_users(self.servicer.peer_ids)
while num_registered < quorum_threshold:
# sleep for 5 seconds
print(f"Waiting for {quorum_threshold} users to register")
print(f"Waiting for {quorum_threshold} users to register, {num_registered} have registered so far")
time.sleep(5)
num_registered = self.get_registered_users(self.servicer.peer_ids)
# TODO: Implement a timeout here and if the timeout is reached
# then set quorum to False for all registered users
# and exit the program
Expand Down
3 changes: 2 additions & 1 deletion src/utils/data_utils.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
import importlib
from typing import Tuple
import numpy as np
import torch
import torchvision.transforms as T
Expand Down Expand Up @@ -92,7 +93,7 @@ def filter_by_class(dataset, classes):
return Subset(dataset, indices), indices


def random_samples(dataset, num_samples):
def random_samples(dataset, num_samples) -> Tuple[Subset, np.ndarray]:
"""
Returns a random subset of samples from the dataset.
"""
Expand Down
Loading

0 comments on commit 43d016e

Please sign in to comment.