Skip to content

Commit

Permalink
Add memory logging (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
rishi-s8 authored Oct 28, 2024
1 parent 277d1a9 commit 28b8291
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import random
import time
import torch.utils.data
import gc

from utils.communication.comm_utils import CommunicationManager
from utils.plot_utils import PlotUtils
Expand Down Expand Up @@ -118,6 +119,8 @@ def __init__(
dropout_rng = random.Random(dropout_seed)
self.dropout = NodeDropout(self.node_id, config["dropout_dicts"], dropout_rng)

self.log_memory = config.get("log_memory", False)

def set_constants(self) -> None:
"""Add docstring here"""
self.best_acc = 0.0
Expand Down Expand Up @@ -367,6 +370,24 @@ def push(self, neighbors: List[int]) -> None:

self.comm_utils.send(neighbors, data_to_send)

def calculate_cpu_tensor_memory(self) -> int:
    """Return the total number of bytes held by live CPU-resident tensors.

    Walks every object currently tracked by the garbage collector and
    sums ``element_size() * nelement()`` for each torch tensor whose
    device type is ``'cpu'``. Each tensor is counted in full, so views
    that share the same underlying storage are each counted separately.
    """
    return sum(
        tensor.element_size() * tensor.nelement()
        for tensor in gc.get_objects()
        if torch.is_tensor(tensor) and tensor.device.type == 'cpu'
    )

def get_memory_metrics(self) -> Dict[str, Any]:
    """Return memory-usage metrics, or an empty dict when logging is off.

    Keys (present only when ``self.log_memory`` is truthy):
        peak_dram: total bytes currently held by CPU tensors, via
            ``calculate_cpu_tensor_memory``. NOTE(review): this is a
            point-in-time total, not a true peak, despite the key name.
        peak_gpu: peak bytes allocated on the GPU since process start
            (0 when CUDA is unavailable).
    """
    if not self.log_memory:
        return {}
    # Guard the CUDA query so CPU-only / CUDA-less torch builds never
    # touch the CUDA API (older versions could raise here).
    peak_gpu = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0
    return {
        "peak_dram": self.calculate_cpu_tensor_memory(),
        "peak_gpu": peak_gpu,
    }

class BaseClient(BaseNode):
"""
Abstract class for all algorithms
Expand Down
6 changes: 4 additions & 2 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def run_protocol(self):
self.receive_and_aggregate()

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()

stats.update(self.get_memory_metrics())

self.log_metrics(stats=stats, iteration=round)


Expand Down Expand Up @@ -161,7 +162,7 @@ def single_round(self):
"""
Runs the whole training procedure
"""
self.receive_and_aggregate()
self.receive_and_aggregate()

def run_protocol(self):
stats: Dict[str, Any] = {}
Expand All @@ -173,4 +174,5 @@ def run_protocol(self):
self.single_round()
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test()
stats.update(self.get_memory_metrics())
self.log_metrics(stats=stats, iteration=round)
2 changes: 2 additions & 0 deletions src/algos/fl_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def run_protocol(self):
self.push(self.server_node)

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
stats.update(self.get_memory_metrics())

self.log_metrics(stats=stats, iteration=round)

Expand Down Expand Up @@ -68,5 +69,6 @@ def run_protocol(self):
self.single_round()
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test()
stats.update(self.get_memory_metrics())
self.log_metrics(stats=stats, iteration=round)
self.local_round_done()
1 change: 1 addition & 0 deletions src/algos/fl_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def run_protocol(self) -> None:
# evaluate the model on the test data
# Inside FedStaticNode.run_protocol()
stats["test_loss"], stats["test_acc"] = self.local_test()
stats.update(self.get_memory_metrics())
self.log_metrics(stats=stats, iteration=it)


Expand Down
3 changes: 3 additions & 0 deletions src/algos/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def run_protocol(self) -> None:
# evaluate the model on the test data
# Inside FedStaticNode.run_protocol()
stats["test_loss"], stats["test_acc"] = self.local_test()

stats.update(self.get_memory_metrics())

self.log_metrics(stats=stats, iteration=it)
self.local_round_done()

Expand Down
1 change: 1 addition & 0 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"test_label_distribution": "iid",
"exp_keys": [],
"dropout_dicts": dropout_dicts,
"log_memory": True,
}

current_config = grpc_system_config
Expand Down

0 comments on commit 28b8291

Please sign in to comment.