diff --git a/src/algos/base_class.py b/src/algos/base_class.py index ca6cba16..ce39a99b 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -14,6 +14,7 @@ import random import time import torch.utils.data +import gc from utils.communication.comm_utils import CommunicationManager from utils.plot_utils import PlotUtils @@ -118,6 +119,8 @@ def __init__( dropout_rng = random.Random(dropout_seed) self.dropout = NodeDropout(self.node_id, config["dropout_dicts"], dropout_rng) + self.log_memory = config.get("log_memory", False) + def set_constants(self) -> None: """Add docstring here""" self.best_acc = 0.0 @@ -367,6 +370,24 @@ def push(self, neighbors: List[int]) -> None: self.comm_utils.send(neighbors, data_to_send) + def calculate_cpu_tensor_memory(self) -> int: + total_memory = 0 + for obj in gc.get_objects(): + if torch.is_tensor(obj) and obj.device.type == 'cpu': + total_memory += obj.element_size() * obj.nelement() + return total_memory + + def get_memory_metrics(self) -> Dict[str, Any]: + """ + Get memory metrics + """ + if self.log_memory: + return { + "peak_dram": self.calculate_cpu_tensor_memory(), + "peak_gpu": torch.cuda.max_memory_allocated(), + } + return {} + class BaseClient(BaseNode): """ Abstract class for all algorithms diff --git a/src/algos/fl.py b/src/algos/fl.py index a45a65a1..f1558cf7 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -83,7 +83,8 @@ def run_protocol(self): self.receive_and_aggregate() stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() - + stats.update(self.get_memory_metrics()) + self.log_metrics(stats=stats, iteration=round) @@ -161,7 +162,7 @@ def single_round(self): """ Runs the whole training procedure """ - self.receive_and_aggregate() + self.receive_and_aggregate() def run_protocol(self): stats: Dict[str, Any] = {} @@ -173,4 +174,5 @@ def run_protocol(self): self.single_round() stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test() + stats.update(self.get_memory_metrics()) self.log_metrics(stats=stats, iteration=round) diff --git a/src/algos/fl_push.py b/src/algos/fl_push.py index 3180c9e8..54fd46f3 100644 --- a/src/algos/fl_push.py +++ b/src/algos/fl_push.py @@ -35,6 +35,7 @@ def run_protocol(self): self.push(self.server_node) stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() + stats.update(self.get_memory_metrics()) self.log_metrics(stats=stats, iteration=round) @@ -68,5 +69,6 @@ def run_protocol(self): self.single_round() stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test() + stats.update(self.get_memory_metrics()) self.log_metrics(stats=stats, iteration=round) self.local_round_done() diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index af0d764a..c9beede3 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -60,6 +60,7 @@ def run_protocol(self) -> None: # evaluate the model on the test data # Inside FedStaticNode.run_protocol() stats["test_loss"], stats["test_acc"] = self.local_test() + stats.update(self.log_memory()) self.log_metrics(stats=stats, iteration=it) diff --git a/src/algos/swift.py b/src/algos/swift.py index 0d384197..196c7241 100644 --- a/src/algos/swift.py +++ b/src/algos/swift.py @@ -53,6 +53,9 @@ def run_protocol(self) -> None: # evaluate the model on the test data # Inside FedStaticNode.run_protocol() stats["test_loss"], stats["test_acc"] = self.local_test() + + stats.update(self.get_memory_metrics()) + self.log_metrics(stats=stats, iteration=it) self.local_round_done() diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 4aa59532..7041813c 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -351,6 +351,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "test_label_distribution": "iid", "exp_keys": [], "dropout_dicts": dropout_dicts, + "log_memory": True, } current_config = grpc_system_config