diff --git a/src/algos/base_class.py b/src/algos/base_class.py
index 894d66e..9094b7f 100644
--- a/src/algos/base_class.py
+++ b/src/algos/base_class.py
@@ -402,7 +402,7 @@ def push(self, neighbors: int | List[int]) -> None:
     def calculate_cpu_tensor_memory(self) -> int:
         total_memory = 0
         for obj in gc.get_objects():
-            if torch.is_tensor(obj) and obj.device.type == 'cpu':
+            if torch.is_tensor(obj) and obj.device.type == 'cpu': # type: ignore
                 total_memory += obj.element_size() * obj.nelement()
         return total_memory
 
@@ -410,8 +410,10 @@ def get_memory_metrics(self) -> Tuple[float | int, float | int]:
         """
         Get memory metrics
         """
-        peak_dram = self.calculate_cpu_tensor_memory() if self.log_memory else 0
-        peak_gpu = torch.cuda.max_memory_allocated() if torch.cuda.is_available() and self.log_memory else 0
+        peak_dram, peak_gpu = 0, 0
+        if self.log_memory:
+            peak_dram = self.calculate_cpu_tensor_memory()
+            peak_gpu = int(torch.cuda.max_memory_allocated()) # type: ignore
         return peak_dram, peak_gpu
 
 class BaseClient(BaseNode):
diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py
index ca137c3..7fb687a 100644
--- a/src/utils/communication/grpc/main.py
+++ b/src/utils/communication/grpc/main.py
@@ -465,7 +465,7 @@ def receive_pushed(self) -> List[OrderedDict[str, Any]]:
         for _ in range(self.servicer.received_data.qsize()):
             item = self.servicer.received_data.get()
             round = item.get("round", 0)
-            if round <= self_round:
+            if (not self.synchronous) or round <= self_round:
                 items.append(item)
             else:
                 self.servicer.received_data.put(item)
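The `base_class.py` hunks gate both memory probes behind `self.log_memory`, so disabled runs skip the expensive `gc.get_objects()` walk entirely. Below is a minimal, self-contained sketch of that gating; `MemoryProbe` is a hypothetical stand-in, not the repo's `BaseNode`, and the comment about `max_memory_allocated()` reflects current PyTorch behavior, which is presumably why the `is_available()` guard could be dropped:

```python
# Standalone sketch of the reworked get_memory_metrics gating.
# MemoryProbe is hypothetical; the real methods live on BaseNode
# in src/algos/base_class.py.
import gc
from typing import Tuple

import torch


class MemoryProbe:
    def __init__(self, log_memory: bool) -> None:
        self.log_memory = log_memory

    def calculate_cpu_tensor_memory(self) -> int:
        # Sum the byte size of every live CPU tensor the GC can see.
        total_memory = 0
        for obj in gc.get_objects():
            if torch.is_tensor(obj) and obj.device.type == "cpu":
                total_memory += obj.element_size() * obj.nelement()
        return total_memory

    def get_memory_metrics(self) -> Tuple[int, int]:
        # Both metrics stay 0 unless logging is enabled, so disabled
        # runs never pay for the gc.get_objects() walk.
        peak_dram, peak_gpu = 0, 0
        if self.log_memory:
            peak_dram = self.calculate_cpu_tensor_memory()
            # In current PyTorch, max_memory_allocated() returns 0 when
            # CUDA was never initialized, so no is_available() guard is
            # strictly needed here.
            peak_gpu = int(torch.cuda.max_memory_allocated())
        return peak_dram, peak_gpu


if __name__ == "__main__":
    probe = MemoryProbe(log_memory=True)
    x = torch.zeros(1024)  # keeps 4 KiB of CPU tensor memory alive
    dram, gpu = probe.get_memory_metrics()
    print(dram >= x.element_size() * x.nelement(), gpu)
```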
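The grpc hunk makes `receive_pushed` deliver every queued message when running asynchronously, while synchronous mode still defers messages from future rounds by re-enqueuing them. A runnable sketch of that filtering, with a hypothetical `PushReceiver` standing in for the real servicer wiring in `src/utils/communication/grpc/main.py`:

```python
# Standalone sketch of the receive_pushed filtering logic above.
# PushReceiver, current_round, and the flat received_data queue are
# simplifications of the repo's servicer, not its actual API.
import queue
from typing import Any, Dict, List


class PushReceiver:
    def __init__(self, synchronous: bool, current_round: int) -> None:
        self.synchronous = synchronous
        self.current_round = current_round
        self.received_data: "queue.Queue[Dict[str, Any]]" = queue.Queue()

    def receive_pushed(self) -> List[Dict[str, Any]]:
        items: List[Dict[str, Any]] = []
        # qsize() is sampled up front, so items re-enqueued below are
        # not examined again within this call.
        for _ in range(self.received_data.qsize()):
            item = self.received_data.get()
            item_round = item.get("round", 0)
            # Async mode accepts everything; sync mode defers messages
            # from future rounds by putting them back on the queue.
            if (not self.synchronous) or item_round <= self.current_round:
                items.append(item)
            else:
                self.received_data.put(item)
        return items


if __name__ == "__main__":
    recv = PushReceiver(synchronous=True, current_round=3)
    for r in (2, 3, 5):
        recv.received_data.put({"round": r, "payload": f"update-{r}"})
    # Sync: the round-5 message is deferred and stays queued.
    print([m["round"] for m in recv.receive_pushed()])  # [2, 3]
    print(recv.received_data.qsize())                   # 1

    recv.synchronous = False
    # Async: the deferred round-5 message is now delivered.
    print([m["round"] for m in recv.receive_pushed()])  # [5]
```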