From d3acfa24b10d0ac8835a228dd66e972a6f2d920c Mon Sep 17 00:00:00 2001 From: tocean Date: Thu, 11 Jan 2024 02:02:41 +0000 Subject: [PATCH] fix comments --- msamp/fsdp/_runtime_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 566590fd..0e8de20d 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -15,8 +15,8 @@ def _fp8_post_backward_hook(state, handle, *unused): """A post-backward communication hook which supports fp8.""" accumulate_grad = hasattr(state._flat_param, '_saved_grad_shard') - if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0): - raise NotImplementedError('accumulate_grad is not supported for fp8') + if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0: + raise NotImplementedError('accumulate_grad is not supported yet for fp8') old_communication_hook = state._communication_hook state._communication_hook = state._get_fp8_comm_hook()