Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tocean committed Jan 11, 2024
1 parent a8ff215 commit d3acfa2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions msamp/fsdp/_runtime_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
def _fp8_post_backward_hook(state, handle, *unused):
"""A post-backward communication hook which supports fp8."""
accumulate_grad = hasattr(state._flat_param, '_saved_grad_shard')
if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0):
raise NotImplementedError('accumulate_grad is not supported for fp8')
if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0:
raise NotImplementedError('accumulate_grad is not supported yet for fp8')

old_communication_hook = state._communication_hook
state._communication_hook = state._get_fp8_comm_hook()
Expand Down

0 comments on commit d3acfa2

Please sign in to comment.